diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/algorithms.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/algorithms.py new file mode 100644 index 0000000000000000000000000000000000000000..15a07da76d2f7c11daf7492dc7fcab1a78d328cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/algorithms.py @@ -0,0 +1,1747 @@ +""" +Generic data algorithms. This module is experimental at the moment and not +intended for public consumption +""" +from __future__ import annotations + +import decimal +import operator +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Literal, + cast, +) +import warnings + +import numpy as np + +from pandas._libs import ( + algos, + hashtable as htable, + iNaT, + lib, +) +from pandas._typing import ( + AnyArrayLike, + ArrayLike, + AxisInt, + DtypeObj, + TakeIndexer, + npt, +) +from pandas.util._decorators import doc +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.cast import ( + construct_1d_object_array_from_listlike, + np_find_common_type, +) +from pandas.core.dtypes.common import ( + ensure_float64, + ensure_object, + ensure_platform_int, + is_array_like, + is_bool_dtype, + is_complex_dtype, + is_dict_like, + is_extension_array_dtype, + is_float_dtype, + is_integer, + is_integer_dtype, + is_list_like, + is_object_dtype, + is_signed_integer_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.concat import concat_compat +from pandas.core.dtypes.dtypes import ( + BaseMaskedDtype, + CategoricalDtype, + ExtensionDtype, + NumpyEADtype, +) +from pandas.core.dtypes.generic import ( + ABCDatetimeArray, + ABCExtensionArray, + ABCIndex, + ABCMultiIndex, + ABCSeries, + ABCTimedeltaArray, +) +from pandas.core.dtypes.missing import ( + isna, + na_value_for_dtype, +) + +from pandas.core.array_algos.take import take_nd +from pandas.core.construction import ( + array as pd_array, + ensure_wrapped_if_datetimelike, + extract_array, +) +from pandas.core.indexers import validate_indices + +if TYPE_CHECKING: + from pandas._typing import ( + ListLike, + NumpySorter, + NumpyValueArrayLike, + ) + + from pandas import ( + Categorical, + Index, + Series, + ) + from pandas.core.arrays import ( + BaseMaskedArray, + ExtensionArray, + ) + + +# --------------- # +# dtype access # +# --------------- # +def _ensure_data(values: ArrayLike) -> np.ndarray: + """ + routine to ensure that our data is of the correct + input dtype for lower-level routines + + This will coerce: + - ints -> int64 + - uint -> uint64 + - bool -> uint8 + - datetimelike -> i8 + - datetime64tz -> i8 (in local tz) + - categorical -> codes + + Parameters + ---------- + values : np.ndarray or ExtensionArray + + Returns + ------- + np.ndarray + """ + + if not isinstance(values, ABCMultiIndex): + # extract_array would raise + values = extract_array(values, extract_numpy=True) + + if is_object_dtype(values.dtype): + return ensure_object(np.asarray(values)) + + elif isinstance(values.dtype, BaseMaskedDtype): + # i.e. BooleanArray, FloatingArray, IntegerArray + values = cast("BaseMaskedArray", values) + if not values._hasna: + # No pd.NAs -> We can avoid an object-dtype cast (and copy) GH#41816 + # recurse to avoid re-implementing logic for eg bool->uint8 + return _ensure_data(values._data) + return np.asarray(values) + + elif isinstance(values.dtype, CategoricalDtype): + # NB: cases that go through here should NOT be using _reconstruct_data + # on the back-end. + values = cast("Categorical", values) + return values.codes + + elif is_bool_dtype(values.dtype): + if isinstance(values, np.ndarray): + # i.e. actually dtype == np.dtype("bool") + return np.asarray(values).view("uint8") + else: + # e.g. Sparse[bool, False] # TODO: no test cases get here + return np.asarray(values).astype("uint8", copy=False) + + elif is_integer_dtype(values.dtype): + return np.asarray(values) + + elif is_float_dtype(values.dtype): + # Note: checking `values.dtype == "float128"` raises on Windows and 32bit + # error: Item "ExtensionDtype" of "Union[Any, ExtensionDtype, dtype[Any]]" + # has no attribute "itemsize" + if values.dtype.itemsize in [2, 12, 16]: # type: ignore[union-attr] + # we dont (yet) have float128 hashtable support + return ensure_float64(values) + return np.asarray(values) + + elif is_complex_dtype(values.dtype): + return cast(np.ndarray, values) + + # datetimelike + elif needs_i8_conversion(values.dtype): + npvalues = values.view("i8") + npvalues = cast(np.ndarray, npvalues) + return npvalues + + # we have failed, return object + values = np.asarray(values, dtype=object) + return ensure_object(values) + + +def _reconstruct_data( + values: ArrayLike, dtype: DtypeObj, original: AnyArrayLike +) -> ArrayLike: + """ + reverse of _ensure_data + + Parameters + ---------- + values : np.ndarray or ExtensionArray + dtype : np.dtype or ExtensionDtype + original : AnyArrayLike + + Returns + ------- + ExtensionArray or np.ndarray + """ + if isinstance(values, ABCExtensionArray) and values.dtype == dtype: + # Catch DatetimeArray/TimedeltaArray + return values + + if not isinstance(dtype, np.dtype): + # i.e. ExtensionDtype; note we have ruled out above the possibility + # that values.dtype == dtype + cls = dtype.construct_array_type() + + values = cls._from_sequence(values, dtype=dtype) + + else: + values = values.astype(dtype, copy=False) + + return values + + +def _ensure_arraylike(values, func_name: str) -> ArrayLike: + """ + ensure that we are arraylike if not already + """ + if not isinstance(values, (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray)): + # GH#52986 + if func_name != "isin-targets": + # Make an exception for the comps argument in isin. + warnings.warn( + f"{func_name} with argument that is not not a Series, Index, " + "ExtensionArray, or np.ndarray is deprecated and will raise in a " + "future version.", + FutureWarning, + stacklevel=find_stack_level(), + ) + + inferred = lib.infer_dtype(values, skipna=False) + if inferred in ["mixed", "string", "mixed-integer"]: + # "mixed-integer" to ensure we do not cast ["ss", 42] to str GH#22160 + if isinstance(values, tuple): + values = list(values) + values = construct_1d_object_array_from_listlike(values) + else: + values = np.asarray(values) + return values + + +_hashtables = { + "complex128": htable.Complex128HashTable, + "complex64": htable.Complex64HashTable, + "float64": htable.Float64HashTable, + "float32": htable.Float32HashTable, + "uint64": htable.UInt64HashTable, + "uint32": htable.UInt32HashTable, + "uint16": htable.UInt16HashTable, + "uint8": htable.UInt8HashTable, + "int64": htable.Int64HashTable, + "int32": htable.Int32HashTable, + "int16": htable.Int16HashTable, + "int8": htable.Int8HashTable, + "string": htable.StringHashTable, + "object": htable.PyObjectHashTable, +} + + +def _get_hashtable_algo(values: np.ndarray): + """ + Parameters + ---------- + values : np.ndarray + + Returns + ------- + htable : HashTable subclass + values : ndarray + """ + values = _ensure_data(values) + + ndtype = _check_object_for_strings(values) + hashtable = _hashtables[ndtype] + return hashtable, values + + +def _check_object_for_strings(values: np.ndarray) -> str: + """ + Check if we can use string hashtable instead of object hashtable. + + Parameters + ---------- + values : ndarray + + Returns + ------- + str + """ + ndtype = values.dtype.name + if ndtype == "object": + # it's cheaper to use a String Hash Table than Object; we infer + # including nulls because that is the only difference between + # StringHashTable and ObjectHashtable + if lib.is_string_array(values, skipna=False): + ndtype = "string" + return ndtype + + +# --------------- # +# top-level algos # +# --------------- # + + +def unique(values): + """ + Return unique values based on a hash table. + + Uniques are returned in order of appearance. This does NOT sort. + + Significantly faster than numpy.unique for long enough sequences. + Includes NA values. + + Parameters + ---------- + values : 1d array-like + + Returns + ------- + numpy.ndarray or ExtensionArray + + The return can be: + + * Index : when the input is an Index + * Categorical : when the input is a Categorical dtype + * ndarray : when the input is a Series/ndarray + + Return numpy.ndarray or ExtensionArray. + + See Also + -------- + Index.unique : Return unique values from an Index. + Series.unique : Return unique values of Series object. + + Examples + -------- + >>> pd.unique(pd.Series([2, 1, 3, 3])) + array([2, 1, 3]) + + >>> pd.unique(pd.Series([2] + [1] * 5)) + array([2, 1]) + + >>> pd.unique(pd.Series([pd.Timestamp("20160101"), pd.Timestamp("20160101")])) + array(['2016-01-01T00:00:00.000000000'], dtype='datetime64[ns]') + + >>> pd.unique( + ... pd.Series( + ... [ + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... ] + ... ) + ... ) + + ['2016-01-01 00:00:00-05:00'] + Length: 1, dtype: datetime64[ns, US/Eastern] + + >>> pd.unique( + ... pd.Index( + ... [ + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... ] + ... ) + ... ) + DatetimeIndex(['2016-01-01 00:00:00-05:00'], + dtype='datetime64[ns, US/Eastern]', + freq=None) + + >>> pd.unique(np.array(list("baabc"), dtype="O")) + array(['b', 'a', 'c'], dtype=object) + + An unordered Categorical will return categories in the + order of appearance. + + >>> pd.unique(pd.Series(pd.Categorical(list("baabc")))) + ['b', 'a', 'c'] + Categories (3, object): ['a', 'b', 'c'] + + >>> pd.unique(pd.Series(pd.Categorical(list("baabc"), categories=list("abc")))) + ['b', 'a', 'c'] + Categories (3, object): ['a', 'b', 'c'] + + An ordered Categorical preserves the category ordering. + + >>> pd.unique( + ... pd.Series( + ... pd.Categorical(list("baabc"), categories=list("abc"), ordered=True) + ... ) + ... ) + ['b', 'a', 'c'] + Categories (3, object): ['a' < 'b' < 'c'] + + An array of tuples + + >>> pd.unique(pd.Series([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]).values) + array([('a', 'b'), ('b', 'a'), ('a', 'c')], dtype=object) + """ + return unique_with_mask(values) + + +def nunique_ints(values: ArrayLike) -> int: + """ + Return the number of unique values for integer array-likes. + + Significantly faster than pandas.unique for long enough sequences. + No checks are done to ensure input is integral. + + Parameters + ---------- + values : 1d array-like + + Returns + ------- + int : The number of unique values in ``values`` + """ + if len(values) == 0: + return 0 + values = _ensure_data(values) + # bincount requires intp + result = (np.bincount(values.ravel().astype("intp")) != 0).sum() + return result + + +def unique_with_mask(values, mask: npt.NDArray[np.bool_] | None = None): + """See algorithms.unique for docs. Takes a mask for masked arrays.""" + values = _ensure_arraylike(values, func_name="unique") + + if isinstance(values.dtype, ExtensionDtype): + # Dispatch to extension dtype's unique. + return values.unique() + + original = values + hashtable, values = _get_hashtable_algo(values) + + table = hashtable(len(values)) + if mask is None: + uniques = table.unique(values) + uniques = _reconstruct_data(uniques, original.dtype, original) + return uniques + + else: + uniques, mask = table.unique(values, mask=mask) + uniques = _reconstruct_data(uniques, original.dtype, original) + assert mask is not None # for mypy + return uniques, mask.astype("bool") + + +unique1d = unique + + +_MINIMUM_COMP_ARR_LEN = 1_000_000 + + +def isin(comps: ListLike, values: ListLike) -> npt.NDArray[np.bool_]: + """ + Compute the isin boolean array. + + Parameters + ---------- + comps : list-like + values : list-like + + Returns + ------- + ndarray[bool] + Same length as `comps`. + """ + if not is_list_like(comps): + raise TypeError( + "only list-like objects are allowed to be passed " + f"to isin(), you passed a `{type(comps).__name__}`" + ) + if not is_list_like(values): + raise TypeError( + "only list-like objects are allowed to be passed " + f"to isin(), you passed a `{type(values).__name__}`" + ) + + if not isinstance(values, (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray)): + orig_values = list(values) + values = _ensure_arraylike(orig_values, func_name="isin-targets") + + if ( + len(values) > 0 + and values.dtype.kind in "iufcb" + and not is_signed_integer_dtype(comps) + ): + # GH#46485 Use object to avoid upcast to float64 later + # TODO: Share with _find_common_type_compat + values = construct_1d_object_array_from_listlike(orig_values) + + elif isinstance(values, ABCMultiIndex): + # Avoid raising in extract_array + values = np.array(values) + else: + values = extract_array(values, extract_numpy=True, extract_range=True) + + comps_array = _ensure_arraylike(comps, func_name="isin") + comps_array = extract_array(comps_array, extract_numpy=True) + if not isinstance(comps_array, np.ndarray): + # i.e. Extension Array + return comps_array.isin(values) + + elif needs_i8_conversion(comps_array.dtype): + # Dispatch to DatetimeLikeArrayMixin.isin + return pd_array(comps_array).isin(values) + elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps_array.dtype): + # e.g. comps_array are integers and values are datetime64s + return np.zeros(comps_array.shape, dtype=bool) + # TODO: not quite right ... Sparse/Categorical + elif needs_i8_conversion(values.dtype): + return isin(comps_array, values.astype(object)) + + elif isinstance(values.dtype, ExtensionDtype): + return isin(np.asarray(comps_array), np.asarray(values)) + + # GH16012 + # Ensure np.isin doesn't get object types or it *may* throw an exception + # Albeit hashmap has O(1) look-up (vs. O(logn) in sorted array), + # isin is faster for small sizes + if ( + len(comps_array) > _MINIMUM_COMP_ARR_LEN + and len(values) <= 26 + and comps_array.dtype != object + ): + # If the values include nan we need to check for nan explicitly + # since np.nan it not equal to np.nan + if isna(values).any(): + + def f(c, v): + return np.logical_or(np.isin(c, v).ravel(), np.isnan(c)) + + else: + f = lambda a, b: np.isin(a, b).ravel() + + else: + common = np_find_common_type(values.dtype, comps_array.dtype) + values = values.astype(common, copy=False) + comps_array = comps_array.astype(common, copy=False) + f = htable.ismember + + return f(comps_array, values) + + +def factorize_array( + values: np.ndarray, + use_na_sentinel: bool = True, + size_hint: int | None = None, + na_value: object = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[npt.NDArray[np.intp], np.ndarray]: + """ + Factorize a numpy array to codes and uniques. + + This doesn't do any coercion of types or unboxing before factorization. + + Parameters + ---------- + values : ndarray + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + size_hint : int, optional + Passed through to the hashtable's 'get_labels' method + na_value : object, optional + A value in `values` to consider missing. Note: only use this + parameter when you know that you don't have any values pandas would + consider missing in the array (NaN for float data, iNaT for + datetimes, etc.). + mask : ndarray[bool], optional + If not None, the mask is used as indicator for missing values + (True = missing, False = valid) instead of `na_value` or + condition "val != val". + + Returns + ------- + codes : ndarray[np.intp] + uniques : ndarray + """ + original = values + if values.dtype.kind in "mM": + # _get_hashtable_algo will cast dt64/td64 to i8 via _ensure_data, so we + # need to do the same to na_value. We are assuming here that the passed + # na_value is an appropriately-typed NaT. + # e.g. test_where_datetimelike_categorical + na_value = iNaT + + hash_klass, values = _get_hashtable_algo(values) + + table = hash_klass(size_hint or len(values)) + uniques, codes = table.factorize( + values, + na_sentinel=-1, + na_value=na_value, + mask=mask, + ignore_na=use_na_sentinel, + ) + + # re-cast e.g. i8->dt64/td64, uint8->bool + uniques = _reconstruct_data(uniques, original.dtype, original) + + codes = ensure_platform_int(codes) + return codes, uniques + + +@doc( + values=dedent( + """\ + values : sequence + A 1-D sequence. Sequences that aren't pandas objects are + coerced to ndarrays before factorization. + """ + ), + sort=dedent( + """\ + sort : bool, default False + Sort `uniques` and shuffle `codes` to maintain the + relationship. + """ + ), + size_hint=dedent( + """\ + size_hint : int, optional + Hint to the hashtable sizer. + """ + ), +) +def factorize( + values, + sort: bool = False, + use_na_sentinel: bool = True, + size_hint: int | None = None, +) -> tuple[np.ndarray, np.ndarray | Index]: + """ + Encode the object as an enumerated type or categorical variable. + + This method is useful for obtaining a numeric representation of an + array when all that matters is identifying distinct values. `factorize` + is available as both a top-level function :func:`pandas.factorize`, + and as a method :meth:`Series.factorize` and :meth:`Index.factorize`. + + Parameters + ---------- + {values}{sort} + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + + .. versionadded:: 1.5.0 + {size_hint}\ + + Returns + ------- + codes : ndarray + An integer ndarray that's an indexer into `uniques`. + ``uniques.take(codes)`` will have the same values as `values`. + uniques : ndarray, Index, or Categorical + The unique valid values. When `values` is Categorical, `uniques` + is a Categorical. When `values` is some other pandas object, an + `Index` is returned. Otherwise, a 1-D ndarray is returned. + + .. note:: + + Even if there's a missing value in `values`, `uniques` will + *not* contain an entry for it. + + See Also + -------- + cut : Discretize continuous-valued array. + unique : Find the unique value in an array. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + These examples all show factorize as a top-level method like + ``pd.factorize(values)``. The results are identical for methods like + :meth:`Series.factorize`. + + >>> codes, uniques = pd.factorize(np.array(['b', 'b', 'a', 'c', 'b'], dtype="O")) + >>> codes + array([0, 0, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + With ``sort=True``, the `uniques` will be sorted, and `codes` will be + shuffled so that the relationship is the maintained. + + >>> codes, uniques = pd.factorize(np.array(['b', 'b', 'a', 'c', 'b'], dtype="O"), + ... sort=True) + >>> codes + array([1, 1, 0, 2, 1]) + >>> uniques + array(['a', 'b', 'c'], dtype=object) + + When ``use_na_sentinel=True`` (the default), missing values are indicated in + the `codes` with the sentinel value ``-1`` and missing values are not + included in `uniques`. + + >>> codes, uniques = pd.factorize(np.array(['b', None, 'a', 'c', 'b'], dtype="O")) + >>> codes + array([ 0, -1, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + Thus far, we've only factorized lists (which are internally coerced to + NumPy arrays). When factorizing pandas objects, the type of `uniques` + will differ. For Categoricals, a `Categorical` is returned. + + >>> cat = pd.Categorical(['a', 'a', 'c'], categories=['a', 'b', 'c']) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + ['a', 'c'] + Categories (3, object): ['a', 'b', 'c'] + + Notice that ``'b'`` is in ``uniques.categories``, despite not being + present in ``cat.values``. + + For all other pandas objects, an Index of the appropriate type is + returned. + + >>> cat = pd.Series(['a', 'a', 'c']) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + Index(['a', 'c'], dtype='object') + + If NaN is in the values, and we want to include NaN in the uniques of the + values, it can be achieved by setting ``use_na_sentinel=False``. + + >>> values = np.array([1, 2, 1, np.nan]) + >>> codes, uniques = pd.factorize(values) # default: use_na_sentinel=True + >>> codes + array([ 0, 1, 0, -1]) + >>> uniques + array([1., 2.]) + + >>> codes, uniques = pd.factorize(values, use_na_sentinel=False) + >>> codes + array([0, 1, 0, 2]) + >>> uniques + array([ 1., 2., nan]) + """ + # Implementation notes: This method is responsible for 3 things + # 1.) coercing data to array-like (ndarray, Index, extension array) + # 2.) factorizing codes and uniques + # 3.) Maybe boxing the uniques in an Index + # + # Step 2 is dispatched to extension types (like Categorical). They are + # responsible only for factorization. All data coercion, sorting and boxing + # should happen here. + if isinstance(values, (ABCIndex, ABCSeries)): + return values.factorize(sort=sort, use_na_sentinel=use_na_sentinel) + + values = _ensure_arraylike(values, func_name="factorize") + original = values + + if ( + isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray)) + and values.freq is not None + ): + # The presence of 'freq' means we can fast-path sorting and know there + # aren't NAs + codes, uniques = values.factorize(sort=sort) + return codes, uniques + + elif not isinstance(values, np.ndarray): + # i.e. ExtensionArray + codes, uniques = values.factorize(use_na_sentinel=use_na_sentinel) + + else: + values = np.asarray(values) # convert DTA/TDA/MultiIndex + + if not use_na_sentinel and values.dtype == object: + # factorize can now handle differentiating various types of null values. + # These can only occur when the array has object dtype. + # However, for backwards compatibility we only use the null for the + # provided dtype. This may be revisited in the future, see GH#48476. + null_mask = isna(values) + if null_mask.any(): + na_value = na_value_for_dtype(values.dtype, compat=False) + # Don't modify (potentially user-provided) array + values = np.where(null_mask, na_value, values) + + codes, uniques = factorize_array( + values, + use_na_sentinel=use_na_sentinel, + size_hint=size_hint, + ) + + if sort and len(uniques) > 0: + uniques, codes = safe_sort( + uniques, + codes, + use_na_sentinel=use_na_sentinel, + assume_unique=True, + verify=False, + ) + + uniques = _reconstruct_data(uniques, original.dtype, original) + + return codes, uniques + + +def value_counts( + values, + sort: bool = True, + ascending: bool = False, + normalize: bool = False, + bins=None, + dropna: bool = True, +) -> Series: + """ + Compute a histogram of the counts of non-null values. + + Parameters + ---------- + values : ndarray (1-d) + sort : bool, default True + Sort by values + ascending : bool, default False + Sort in ascending order + normalize: bool, default False + If True then compute a relative histogram + bins : integer, optional + Rather than count values, group them into half-open bins, + convenience for pd.cut, only works with numeric data + dropna : bool, default True + Don't include counts of NaN + + Returns + ------- + Series + """ + warnings.warn( + # GH#53493 + "pandas.value_counts is deprecated and will be removed in a " + "future version. Use pd.Series(obj).value_counts() instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + return value_counts_internal( + values, + sort=sort, + ascending=ascending, + normalize=normalize, + bins=bins, + dropna=dropna, + ) + + +def value_counts_internal( + values, + sort: bool = True, + ascending: bool = False, + normalize: bool = False, + bins=None, + dropna: bool = True, +) -> Series: + from pandas import ( + Index, + Series, + ) + + index_name = getattr(values, "name", None) + name = "proportion" if normalize else "count" + + if bins is not None: + from pandas.core.reshape.tile import cut + + if isinstance(values, Series): + values = values._values + + try: + ii = cut(values, bins, include_lowest=True) + except TypeError as err: + raise TypeError("bins argument only works with numeric data.") from err + + # count, remove nulls (from the index), and but the bins + result = ii.value_counts(dropna=dropna) + result.name = name + result = result[result.index.notna()] + result.index = result.index.astype("interval") + result = result.sort_index() + + # if we are dropna and we have NO values + if dropna and (result._values == 0).all(): + result = result.iloc[0:0] + + # normalizing is by len of all (regardless of dropna) + counts = np.array([len(ii)]) + + else: + if is_extension_array_dtype(values): + # handle Categorical and sparse, + result = Series(values, copy=False)._values.value_counts(dropna=dropna) + result.name = name + result.index.name = index_name + counts = result._values + if not isinstance(counts, np.ndarray): + # e.g. ArrowExtensionArray + counts = np.asarray(counts) + + elif isinstance(values, ABCMultiIndex): + # GH49558 + levels = list(range(values.nlevels)) + result = ( + Series(index=values, name=name) + .groupby(level=levels, dropna=dropna) + .size() + ) + result.index.names = values.names + counts = result._values + + else: + values = _ensure_arraylike(values, func_name="value_counts") + keys, counts, _ = value_counts_arraylike(values, dropna) + if keys.dtype == np.float16: + keys = keys.astype(np.float32) + + # For backwards compatibility, we let Index do its normal type + # inference, _except_ for if if infers from object to bool. + idx = Index(keys) + if idx.dtype == bool and keys.dtype == object: + idx = idx.astype(object) + elif ( + idx.dtype != keys.dtype # noqa: PLR1714 # # pylint: disable=R1714 + and idx.dtype != "string[pyarrow_numpy]" + ): + warnings.warn( + # GH#56161 + "The behavior of value_counts with object-dtype is deprecated. " + "In a future version, this will *not* perform dtype inference " + "on the resulting index. To retain the old behavior, use " + "`result.index = result.index.infer_objects()`", + FutureWarning, + stacklevel=find_stack_level(), + ) + idx.name = index_name + + result = Series(counts, index=idx, name=name, copy=False) + + if sort: + result = result.sort_values(ascending=ascending) + + if normalize: + result = result / counts.sum() + + return result + + +# Called once from SparseArray, otherwise could be private +def value_counts_arraylike( + values: np.ndarray, dropna: bool, mask: npt.NDArray[np.bool_] | None = None +) -> tuple[ArrayLike, npt.NDArray[np.int64], int]: + """ + Parameters + ---------- + values : np.ndarray + dropna : bool + mask : np.ndarray[bool] or None, default None + + Returns + ------- + uniques : np.ndarray + counts : np.ndarray[np.int64] + """ + original = values + values = _ensure_data(values) + + keys, counts, na_counter = htable.value_count(values, dropna, mask=mask) + + if needs_i8_conversion(original.dtype): + # datetime, timedelta, or period + + if dropna: + mask = keys != iNaT + keys, counts = keys[mask], counts[mask] + + res_keys = _reconstruct_data(keys, original.dtype, original) + return res_keys, counts, na_counter + + +def duplicated( + values: ArrayLike, + keep: Literal["first", "last", False] = "first", + mask: npt.NDArray[np.bool_] | None = None, +) -> npt.NDArray[np.bool_]: + """ + Return boolean ndarray denoting duplicate values. + + Parameters + ---------- + values : np.ndarray or ExtensionArray + Array over which to check for duplicate values. + keep : {'first', 'last', False}, default 'first' + - ``first`` : Mark duplicates as ``True`` except for the first + occurrence. + - ``last`` : Mark duplicates as ``True`` except for the last + occurrence. + - False : Mark all duplicates as ``True``. + mask : ndarray[bool], optional + array indicating which elements to exclude from checking + + Returns + ------- + duplicated : ndarray[bool] + """ + values = _ensure_data(values) + return htable.duplicated(values, keep=keep, mask=mask) + + +def mode( + values: ArrayLike, dropna: bool = True, mask: npt.NDArray[np.bool_] | None = None +) -> ArrayLike: + """ + Returns the mode(s) of an array. + + Parameters + ---------- + values : array-like + Array over which to check for duplicate values. + dropna : bool, default True + Don't consider counts of NaN/NaT. + + Returns + ------- + np.ndarray or ExtensionArray + """ + values = _ensure_arraylike(values, func_name="mode") + original = values + + if needs_i8_conversion(values.dtype): + # Got here with ndarray; dispatch to DatetimeArray/TimedeltaArray. + values = ensure_wrapped_if_datetimelike(values) + values = cast("ExtensionArray", values) + return values._mode(dropna=dropna) + + values = _ensure_data(values) + + npresult, res_mask = htable.mode(values, dropna=dropna, mask=mask) + if res_mask is not None: + return npresult, res_mask # type: ignore[return-value] + + try: + npresult = np.sort(npresult) + except TypeError as err: + warnings.warn( + f"Unable to sort modes: {err}", + stacklevel=find_stack_level(), + ) + + result = _reconstruct_data(npresult, original.dtype, original) + return result + + +def rank( + values: ArrayLike, + axis: AxisInt = 0, + method: str = "average", + na_option: str = "keep", + ascending: bool = True, + pct: bool = False, +) -> npt.NDArray[np.float64]: + """ + Rank the values along a given axis. + + Parameters + ---------- + values : np.ndarray or ExtensionArray + Array whose values will be ranked. The number of dimensions in this + array must not exceed 2. + axis : int, default 0 + Axis over which to perform rankings. + method : {'average', 'min', 'max', 'first', 'dense'}, default 'average' + The method by which tiebreaks are broken during the ranking. + na_option : {'keep', 'top'}, default 'keep' + The method by which NaNs are placed in the ranking. + - ``keep``: rank each NaN value with a NaN ranking + - ``top``: replace each NaN with either +/- inf so that they + there are ranked at the top + ascending : bool, default True + Whether or not the elements should be ranked in ascending order. + pct : bool, default False + Whether or not to the display the returned rankings in integer form + (e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1). + """ + is_datetimelike = needs_i8_conversion(values.dtype) + values = _ensure_data(values) + + if values.ndim == 1: + ranks = algos.rank_1d( + values, + is_datetimelike=is_datetimelike, + ties_method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + elif values.ndim == 2: + ranks = algos.rank_2d( + values, + axis=axis, + is_datetimelike=is_datetimelike, + ties_method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + else: + raise TypeError("Array with ndim > 2 are not supported.") + + return ranks + + +# ---- # +# take # +# ---- # + + +def take( + arr, + indices: TakeIndexer, + axis: AxisInt = 0, + allow_fill: bool = False, + fill_value=None, +): + """ + Take elements from an array. + + Parameters + ---------- + arr : array-like or scalar value + Non array-likes (sequences/scalars without a dtype) are coerced + to an ndarray. + + .. deprecated:: 2.1.0 + Passing an argument other than a numpy.ndarray, ExtensionArray, + Index, or Series is deprecated. + + indices : sequence of int or one-dimensional np.ndarray of int + Indices to be taken. + axis : int, default 0 + The axis over which to select values. + allow_fill : bool, default False + How to handle negative values in `indices`. + + * False: negative values in `indices` indicate positional indices + from the right (the default). This is similar to :func:`numpy.take`. + + * True: negative values in `indices` indicate + missing values. These values are set to `fill_value`. Any other + negative values raise a ``ValueError``. + + fill_value : any, optional + Fill value to use for NA-indices when `allow_fill` is True. + This may be ``None``, in which case the default NA value for + the type (``self.dtype.na_value``) is used. + + For multi-dimensional `arr`, each *element* is filled with + `fill_value`. + + Returns + ------- + ndarray or ExtensionArray + Same type as the input. + + Raises + ------ + IndexError + When `indices` is out of bounds for the array. + ValueError + When the indexer contains negative values other than ``-1`` + and `allow_fill` is True. + + Notes + ----- + When `allow_fill` is False, `indices` may be whatever dimensionality + is accepted by NumPy for `arr`. + + When `allow_fill` is True, `indices` should be 1-D. + + See Also + -------- + numpy.take : Take elements from an array along an axis. + + Examples + -------- + >>> import pandas as pd + + With the default ``allow_fill=False``, negative numbers indicate + positional indices from the right. + + >>> pd.api.extensions.take(np.array([10, 20, 30]), [0, 0, -1]) + array([10, 10, 30]) + + Setting ``allow_fill=True`` will place `fill_value` in those positions. + + >>> pd.api.extensions.take(np.array([10, 20, 30]), [0, 0, -1], allow_fill=True) + array([10., 10., nan]) + + >>> pd.api.extensions.take(np.array([10, 20, 30]), [0, 0, -1], allow_fill=True, + ... fill_value=-10) + array([ 10, 10, -10]) + """ + if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)): + # GH#52981 + warnings.warn( + "pd.api.extensions.take accepting non-standard inputs is deprecated " + "and will raise in a future version. Pass either a numpy.ndarray, " + "ExtensionArray, Index, or Series instead.", + FutureWarning, + stacklevel=find_stack_level(), + ) + + if not is_array_like(arr): + arr = np.asarray(arr) + + indices = ensure_platform_int(indices) + + if allow_fill: + # Pandas style, -1 means NA + validate_indices(indices, arr.shape[axis]) + result = take_nd( + arr, indices, axis=axis, allow_fill=True, fill_value=fill_value + ) + else: + # NumPy style + result = arr.take(indices, axis=axis) + return result + + +# ------------ # +# searchsorted # +# ------------ # + + +def searchsorted( + arr: ArrayLike, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter | None = None, +) -> npt.NDArray[np.intp] | np.intp: + """ + Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted array `arr` (a) such that, if the + corresponding elements in `value` were inserted before the indices, + the order of `arr` would be preserved. + + Assuming that `arr` is sorted: + + ====== ================================ + `side` returned index `i` satisfies + ====== ================================ + left ``arr[i-1] < value <= self[i]`` + right ``arr[i-1] <= value < self[i]`` + ====== ================================ + + Parameters + ---------- + arr: np.ndarray, ExtensionArray, Series + Input array. If `sorter` is None, then it must be sorted in + ascending order, otherwise `sorter` must be an array of indices + that sort it. + value : array-like or scalar + Values to insert into `arr`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `self`). + sorter : 1-D array-like, optional + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + + Returns + ------- + array of ints or int + If value is array-like, array of insertion points. + If value is scalar, a single integer. + + See Also + -------- + numpy.searchsorted : Similar method from NumPy. + """ + if sorter is not None: + sorter = ensure_platform_int(sorter) + + if ( + isinstance(arr, np.ndarray) + and arr.dtype.kind in "iu" + and (is_integer(value) or is_integer_dtype(value)) + ): + # if `arr` and `value` have different dtypes, `arr` would be + # recast by numpy, causing a slow search. + # Before searching below, we therefore try to give `value` the + # same dtype as `arr`, while guarding against integer overflows. + iinfo = np.iinfo(arr.dtype.type) + value_arr = np.array([value]) if is_integer(value) else np.array(value) + if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all(): + # value within bounds, so no overflow, so can convert value dtype + # to dtype of arr + dtype = arr.dtype + else: + dtype = value_arr.dtype + + if is_integer(value): + # We know that value is int + value = cast(int, dtype.type(value)) + else: + value = pd_array(cast(ArrayLike, value), dtype=dtype) + else: + # E.g. if `arr` is an array with dtype='datetime64[ns]' + # and `value` is a pd.Timestamp, we may need to convert value + arr = ensure_wrapped_if_datetimelike(arr) + + # Argument 1 to "searchsorted" of "ndarray" has incompatible type + # "Union[NumpyValueArrayLike, ExtensionArray]"; expected "NumpyValueArrayLike" + return arr.searchsorted(value, side=side, sorter=sorter) # type: ignore[arg-type] + + +# ---- # +# diff # +# ---- # + +_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"} + + +def diff(arr, n: int, axis: AxisInt = 0): + """ + difference of n between self, + analogous to s-s.shift(n) + + Parameters + ---------- + arr : ndarray or ExtensionArray + n : int + number of periods + axis : {0, 1} + axis to shift on + stacklevel : int, default 3 + The stacklevel for the lost dtype warning. + + Returns + ------- + shifted + """ + + n = int(n) + na = np.nan + dtype = arr.dtype + + is_bool = is_bool_dtype(dtype) + if is_bool: + op = operator.xor + else: + op = operator.sub + + if isinstance(dtype, NumpyEADtype): + # NumpyExtensionArray cannot necessarily hold shifted versions of itself. + arr = arr.to_numpy() + dtype = arr.dtype + + if not isinstance(arr, np.ndarray): + # i.e ExtensionArray + if hasattr(arr, f"__{op.__name__}__"): + if axis != 0: + raise ValueError(f"cannot diff {type(arr).__name__} on axis={axis}") + return op(arr, arr.shift(n)) + else: + raise TypeError( + f"{type(arr).__name__} has no 'diff' method. " + "Convert to a suitable dtype prior to calling 'diff'." + ) + + is_timedelta = False + if arr.dtype.kind in "mM": + dtype = np.int64 + arr = arr.view("i8") + na = iNaT + is_timedelta = True + + elif is_bool: + # We have to cast in order to be able to hold np.nan + dtype = np.object_ + + elif dtype.kind in "iu": + # We have to cast in order to be able to hold np.nan + + # int8, int16 are incompatible with float64, + # see https://github.com/cython/cython/issues/2646 + if arr.dtype.name in ["int8", "int16"]: + dtype = np.float32 + else: + dtype = np.float64 + + orig_ndim = arr.ndim + if orig_ndim == 1: + # reshape so we can always use algos.diff_2d + arr = arr.reshape(-1, 1) + # TODO: require axis == 0 + + dtype = np.dtype(dtype) + out_arr = np.empty(arr.shape, dtype=dtype) + + na_indexer = [slice(None)] * 2 + na_indexer[axis] = slice(None, n) if n >= 0 else slice(n, None) + out_arr[tuple(na_indexer)] = na + + if arr.dtype.name in _diff_special: + # TODO: can diff_2d dtype specialization troubles be fixed by defining + # out_arr inside diff_2d? + algos.diff_2d(arr, out_arr, n, axis, datetimelike=is_timedelta) + else: + # To keep mypy happy, _res_indexer is a list while res_indexer is + # a tuple, ditto for lag_indexer. + _res_indexer = [slice(None)] * 2 + _res_indexer[axis] = slice(n, None) if n >= 0 else slice(None, n) + res_indexer = tuple(_res_indexer) + + _lag_indexer = [slice(None)] * 2 + _lag_indexer[axis] = slice(None, -n) if n > 0 else slice(-n, None) + lag_indexer = tuple(_lag_indexer) + + out_arr[res_indexer] = op(arr[res_indexer], arr[lag_indexer]) + + if is_timedelta: + out_arr = out_arr.view("timedelta64[ns]") + + if orig_ndim == 1: + out_arr = out_arr[:, 0] + return out_arr + + +# -------------------------------------------------------------------- +# Helper functions + + +# Note: safe_sort is in algorithms.py instead of sorting.py because it is +# low-dependency, is used in this module, and used private methods from +# this module. +def safe_sort( + values: Index | ArrayLike, + codes: npt.NDArray[np.intp] | None = None, + use_na_sentinel: bool = True, + assume_unique: bool = False, + verify: bool = True, +) -> AnyArrayLike | tuple[AnyArrayLike, np.ndarray]: + """ + Sort ``values`` and reorder corresponding ``codes``. + + ``values`` should be unique if ``codes`` is not None. + Safe for use with mixed types (int, str), orders ints before strs. + + Parameters + ---------- + values : list-like + Sequence; must be unique if ``codes`` is not None. + codes : np.ndarray[intp] or None, default None + Indices to ``values``. All out of bound indices are treated as + "not found" and will be masked with ``-1``. + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + assume_unique : bool, default False + When True, ``values`` are assumed to be unique, which can speed up + the calculation. Ignored when ``codes`` is None. + verify : bool, default True + Check if codes are out of bound for the values and put out of bound + codes equal to ``-1``. If ``verify=False``, it is assumed there + are no out of bound codes. Ignored when ``codes`` is None. + + Returns + ------- + ordered : AnyArrayLike + Sorted ``values`` + new_codes : ndarray + Reordered ``codes``; returned when ``codes`` is not None. + + Raises + ------ + TypeError + * If ``values`` is not list-like or if ``codes`` is neither None + nor list-like + * If ``values`` cannot be sorted + ValueError + * If ``codes`` is not None and ``values`` contain duplicates. + """ + if not isinstance(values, (np.ndarray, ABCExtensionArray, ABCIndex)): + raise TypeError( + "Only np.ndarray, ExtensionArray, and Index objects are allowed to " + "be passed to safe_sort as values" + ) + + sorter = None + ordered: AnyArrayLike + + if ( + not isinstance(values.dtype, ExtensionDtype) + and lib.infer_dtype(values, skipna=False) == "mixed-integer" + ): + ordered = _sort_mixed(values) + else: + try: + sorter = values.argsort() + ordered = values.take(sorter) + except (TypeError, decimal.InvalidOperation): + # Previous sorters failed or were not applicable, try `_sort_mixed` + # which would work, but which fails for special case of 1d arrays + # with tuples. + if values.size and isinstance(values[0], tuple): + # error: Argument 1 to "_sort_tuples" has incompatible type + # "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected + # "ndarray[Any, Any]" + ordered = _sort_tuples(values) # type: ignore[arg-type] + else: + ordered = _sort_mixed(values) + + # codes: + + if codes is None: + return ordered + + if not is_list_like(codes): + raise TypeError( + "Only list-like objects or None are allowed to " + "be passed to safe_sort as codes" + ) + codes = ensure_platform_int(np.asarray(codes)) + + if not assume_unique and not len(unique(values)) == len(values): + raise ValueError("values should be unique if codes is not None") + + if sorter is None: + # mixed types + # error: Argument 1 to "_get_hashtable_algo" has incompatible type + # "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected + # "ndarray[Any, Any]" + hash_klass, values = _get_hashtable_algo(values) # type: ignore[arg-type] + t = hash_klass(len(values)) + t.map_locations(values) + sorter = ensure_platform_int(t.lookup(ordered)) + + if use_na_sentinel: + # take_nd is faster, but only works for na_sentinels of -1 + order2 = sorter.argsort() + if verify: + mask = (codes < -len(values)) | (codes >= len(values)) + codes[mask] = 0 + else: + mask = None + new_codes = take_nd(order2, codes, fill_value=-1) + else: + reverse_indexer = np.empty(len(sorter), dtype=int) + reverse_indexer.put(sorter, np.arange(len(sorter))) + # Out of bound indices will be masked with `-1` next, so we + # may deal with them here without performance loss using `mode='wrap'` + new_codes = reverse_indexer.take(codes, mode="wrap") + + if use_na_sentinel: + mask = codes == -1 + if verify: + mask = mask | (codes < -len(values)) | (codes >= len(values)) + + if use_na_sentinel and mask is not None: + np.putmask(new_codes, mask, -1) + + return ordered, ensure_platform_int(new_codes) + + +def _sort_mixed(values) -> AnyArrayLike: + """order ints before strings before nulls in 1d arrays""" + str_pos = np.array([isinstance(x, str) for x in values], dtype=bool) + null_pos = np.array([isna(x) for x in values], dtype=bool) + num_pos = ~str_pos & ~null_pos + str_argsort = np.argsort(values[str_pos]) + num_argsort = np.argsort(values[num_pos]) + # convert boolean arrays to positional indices, then order by underlying values + str_locs = str_pos.nonzero()[0].take(str_argsort) + num_locs = num_pos.nonzero()[0].take(num_argsort) + null_locs = null_pos.nonzero()[0] + locs = np.concatenate([num_locs, str_locs, null_locs]) + return values.take(locs) + + +def _sort_tuples(values: np.ndarray) -> np.ndarray: + """ + Convert array of tuples (1d) to array of arrays (2d). + We need to keep the columns separately as they contain different types and + nans (can't use `np.sort` as it may fail when str and nan are mixed in a + column as types cannot be compared). + """ + from pandas.core.internals.construction import to_arrays + from pandas.core.sorting import lexsort_indexer + + arrays, _ = to_arrays(values, None) + indexer = lexsort_indexer(arrays, orders=True) + return values[indexer] + + +def union_with_duplicates( + lvals: ArrayLike | Index, rvals: ArrayLike | Index +) -> ArrayLike | Index: + """ + Extracts the union from lvals and rvals with respect to duplicates and nans in + both arrays. + + Parameters + ---------- + lvals: np.ndarray or ExtensionArray + left values which is ordered in front. + rvals: np.ndarray or ExtensionArray + right values ordered after lvals. + + Returns + ------- + np.ndarray or ExtensionArray + Containing the unsorted union of both arrays. + + Notes + ----- + Caller is responsible for ensuring lvals.dtype == rvals.dtype. + """ + from pandas import Series + + with warnings.catch_warnings(): + # filter warning from object dtype inference; we will end up discarding + # the index here, so the deprecation does not affect the end result here. + warnings.filterwarnings( + "ignore", + "The behavior of value_counts with object-dtype is deprecated", + category=FutureWarning, + ) + l_count = value_counts_internal(lvals, dropna=False) + r_count = value_counts_internal(rvals, dropna=False) + l_count, r_count = l_count.align(r_count, fill_value=0) + final_count = np.maximum(l_count.values, r_count.values) + final_count = Series(final_count, index=l_count.index, dtype="int", copy=False) + if isinstance(lvals, ABCMultiIndex) and isinstance(rvals, ABCMultiIndex): + unique_vals = lvals.append(rvals).unique() + else: + if isinstance(lvals, ABCIndex): + lvals = lvals._values + if isinstance(rvals, ABCIndex): + rvals = rvals._values + # error: List item 0 has incompatible type "Union[ExtensionArray, + # ndarray[Any, Any], Index]"; expected "Union[ExtensionArray, + # ndarray[Any, Any]]" + combined = concat_compat([lvals, rvals]) # type: ignore[list-item] + unique_vals = unique(combined) + unique_vals = ensure_wrapped_if_datetimelike(unique_vals) + repeats = final_count.reindex(unique_vals).values + return np.repeat(unique_vals, repeats) + + +def map_array( + arr: ArrayLike, + mapper, + na_action: Literal["ignore"] | None = None, + convert: bool = True, +) -> np.ndarray | ExtensionArray | Index: + """ + Map values using an input mapping or function. + + Parameters + ---------- + mapper : function, dict, or Series + Mapping correspondence. + na_action : {None, 'ignore'}, default None + If 'ignore', propagate NA values, without passing them to the + mapping correspondence. + convert : bool, default True + Try to find better dtype for elementwise function results. If + False, leave as dtype=object. + + Returns + ------- + Union[ndarray, Index, ExtensionArray] + The output of the mapping function applied to the array. + If the function returns a tuple with more than one element + a MultiIndex will be returned. + """ + if na_action not in (None, "ignore"): + msg = f"na_action must either be 'ignore' or None, {na_action} was passed" + raise ValueError(msg) + + # we can fastpath dict/Series to an efficient map + # as we know that we are not going to have to yield + # python types + if is_dict_like(mapper): + if isinstance(mapper, dict) and hasattr(mapper, "__missing__"): + # If a dictionary subclass defines a default value method, + # convert mapper to a lookup function (GH #15999). + dict_with_default = mapper + mapper = lambda x: dict_with_default[ + np.nan if isinstance(x, float) and np.isnan(x) else x + ] + else: + # Dictionary does not have a default. Thus it's safe to + # convert to an Series for efficiency. + # we specify the keys here to handle the + # possibility that they are tuples + + # The return value of mapping with an empty mapper is + # expected to be pd.Series(np.nan, ...). As np.nan is + # of dtype float64 the return value of this method should + # be float64 as well + from pandas import Series + + if len(mapper) == 0: + mapper = Series(mapper, dtype=np.float64) + else: + mapper = Series(mapper) + + if isinstance(mapper, ABCSeries): + if na_action == "ignore": + mapper = mapper[mapper.index.notna()] + + # Since values were input this means we came from either + # a dict or a series and mapper should be an index + indexer = mapper.index.get_indexer(arr) + new_values = take_nd(mapper._values, indexer) + + return new_values + + if not len(arr): + return arr.copy() + + # we must convert to python types + values = arr.astype(object, copy=False) + if na_action is None: + return lib.map_infer(values, mapper, convert=convert) + else: + return lib.map_infer_mask( + values, mapper, mask=isna(values).view(np.uint8), convert=convert + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/flags.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..aff7a15f283bafe1459173070e64df4caef0d45d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/flags.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +import weakref + +if TYPE_CHECKING: + from pandas.core.generic import NDFrame + + +class Flags: + """ + Flags that apply to pandas objects. + + Parameters + ---------- + obj : Series or DataFrame + The object these flags are associated with. + allows_duplicate_labels : bool, default True + Whether to allow duplicate labels in this object. By default, + duplicate labels are permitted. Setting this to ``False`` will + cause an :class:`errors.DuplicateLabelError` to be raised when + `index` (or columns for DataFrame) is not unique, or any + subsequent operation on introduces duplicates. + See :ref:`duplicates.disallow` for more. + + .. warning:: + + This is an experimental feature. Currently, many methods fail to + propagate the ``allows_duplicate_labels`` value. In future versions + it is expected that every method taking or returning one or more + DataFrame or Series objects will propagate ``allows_duplicate_labels``. + + Examples + -------- + Attributes can be set in two ways: + + >>> df = pd.DataFrame() + >>> df.flags + + >>> df.flags.allows_duplicate_labels = False + >>> df.flags + + + >>> df.flags['allows_duplicate_labels'] = True + >>> df.flags + + """ + + _keys: set[str] = {"allows_duplicate_labels"} + + def __init__(self, obj: NDFrame, *, allows_duplicate_labels: bool) -> None: + self._allows_duplicate_labels = allows_duplicate_labels + self._obj = weakref.ref(obj) + + @property + def allows_duplicate_labels(self) -> bool: + """ + Whether this object allows duplicate labels. + + Setting ``allows_duplicate_labels=False`` ensures that the + index (and columns of a DataFrame) are unique. Most methods + that accept and return a Series or DataFrame will propagate + the value of ``allows_duplicate_labels``. + + See :ref:`duplicates` for more. + + See Also + -------- + DataFrame.attrs : Set global metadata on this object. + DataFrame.set_flags : Set global flags on this object. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2]}, index=['a', 'a']) + >>> df.flags.allows_duplicate_labels + True + >>> df.flags.allows_duplicate_labels = False + Traceback (most recent call last): + ... + pandas.errors.DuplicateLabelError: Index has duplicates. + positions + label + a [0, 1] + """ + return self._allows_duplicate_labels + + @allows_duplicate_labels.setter + def allows_duplicate_labels(self, value: bool) -> None: + value = bool(value) + obj = self._obj() + if obj is None: + raise ValueError("This flag's object has been deleted.") + + if not value: + for ax in obj.axes: + ax._maybe_check_unique() + + self._allows_duplicate_labels = value + + def __getitem__(self, key: str): + if key not in self._keys: + raise KeyError(key) + + return getattr(self, key) + + def __setitem__(self, key: str, value) -> None: + if key not in self._keys: + raise ValueError(f"Unknown flag {key}. Must be one of {self._keys}") + setattr(self, key, value) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other) -> bool: + if isinstance(other, type(self)): + return self.allows_duplicate_labels == other.allows_duplicate_labels + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/roperator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/roperator.py new file mode 100644 index 0000000000000000000000000000000000000000..2f320f4e9c6b984b64e0fc1268e50a8ad1a7e1fe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/core/roperator.py @@ -0,0 +1,62 @@ +""" +Reversed Operations not available in the stdlib operator module. +Defining these instead of using lambdas allows us to reference them by name. +""" +from __future__ import annotations + +import operator + + +def radd(left, right): + return right + left + + +def rsub(left, right): + return right - left + + +def rmul(left, right): + return right * left + + +def rdiv(left, right): + return right / left + + +def rtruediv(left, right): + return right / left + + +def rfloordiv(left, right): + return right // left + + +def rmod(left, right): + # check if right is a string as % is the string + # formatting operation; this is a TypeError + # otherwise perform the op + if isinstance(right, str): + typ = type(left).__name__ + raise TypeError(f"{typ} cannot perform the operation mod") + + return right % left + + +def rdivmod(left, right): + return divmod(right, left) + + +def rpow(left, right): + return right**left + + +def rand_(left, right): + return operator.and_(right, left) + + +def ror_(left, right): + return operator.or_(right, left) + + +def rxor(left, right): + return operator.xor(right, left) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_aggregation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_aggregation.py new file mode 100644 index 0000000000000000000000000000000000000000..7695c953712ed9925e4e804d0db1e8cf606a97eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_aggregation.py @@ -0,0 +1,93 @@ +import numpy as np +import pytest + +from pandas.core.apply import ( + _make_unique_kwarg_list, + maybe_mangle_lambdas, +) + + +def test_maybe_mangle_lambdas_passthrough(): + assert maybe_mangle_lambdas("mean") == "mean" + assert maybe_mangle_lambdas(lambda x: x).__name__ == "" + # don't mangel single lambda. + assert maybe_mangle_lambdas([lambda x: x])[0].__name__ == "" + + +def test_maybe_mangle_lambdas_listlike(): + aggfuncs = [lambda x: 1, lambda x: 2] + result = maybe_mangle_lambdas(aggfuncs) + assert result[0].__name__ == "" + assert result[1].__name__ == "" + assert aggfuncs[0](None) == result[0](None) + assert aggfuncs[1](None) == result[1](None) + + +def test_maybe_mangle_lambdas(): + func = {"A": [lambda x: 0, lambda x: 1]} + result = maybe_mangle_lambdas(func) + assert result["A"][0].__name__ == "" + assert result["A"][1].__name__ == "" + + +def test_maybe_mangle_lambdas_args(): + func = {"A": [lambda x, a, b=1: (0, a, b), lambda x: 1]} + result = maybe_mangle_lambdas(func) + assert result["A"][0].__name__ == "" + assert result["A"][1].__name__ == "" + + assert func["A"][0](0, 1) == (0, 1, 1) + assert func["A"][0](0, 1, 2) == (0, 1, 2) + assert func["A"][0](0, 2, b=3) == (0, 2, 3) + + +def test_maybe_mangle_lambdas_named(): + func = {"C": np.mean, "D": {"foo": np.mean, "bar": np.mean}} + result = maybe_mangle_lambdas(func) + assert result == func + + +@pytest.mark.parametrize( + "order, expected_reorder", + [ + ( + [ + ("height", ""), + ("height", "max"), + ("weight", "max"), + ("height", ""), + ("weight", ""), + ], + [ + ("height", "_0"), + ("height", "max"), + ("weight", "max"), + ("height", "_1"), + ("weight", ""), + ], + ), + ( + [ + ("col2", "min"), + ("col1", ""), + ("col1", ""), + ("col1", ""), + ], + [ + ("col2", "min"), + ("col1", "_0"), + ("col1", "_1"), + ("col1", "_2"), + ], + ), + ( + [("col", ""), ("col", ""), ("col", "")], + [("col", "_0"), ("col", "_1"), ("col", "_2")], + ), + ], +) +def test_make_unique(order, expected_reorder): + # GH 27519, test if make_unique function reorders correctly + result = _make_unique_kwarg_list(order) + + assert result == expected_reorder diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_algos.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..718d1b3ee2e834507919cd1e46b2e2bead191589 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_algos.py @@ -0,0 +1,2041 @@ +from datetime import datetime +import struct + +import numpy as np +import pytest + +from pandas._libs import ( + algos as libalgos, + hashtable as ht, +) + +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_complex_dtype, + is_float_dtype, + is_integer_dtype, + is_object_dtype, +) +from pandas.core.dtypes.dtypes import CategoricalDtype + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + DatetimeIndex, + Index, + IntervalIndex, + MultiIndex, + NaT, + Period, + PeriodIndex, + Series, + Timedelta, + Timestamp, + cut, + date_range, + timedelta_range, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +import pandas.core.algorithms as algos +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) +import pandas.core.common as com + + +class TestFactorize: + def test_factorize_complex(self): + # GH#17927 + array = [1, 2, 2 + 1j] + msg = "factorize with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + labels, uniques = algos.factorize(array) + + expected_labels = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(labels, expected_labels) + + # Should return a complex dtype in the future + expected_uniques = np.array([(1 + 0j), (2 + 0j), (2 + 1j)], dtype=object) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize("sort", [True, False]) + def test_factorize(self, index_or_series_obj, sort): + obj = index_or_series_obj + result_codes, result_uniques = obj.factorize(sort=sort) + + constructor = Index + if isinstance(obj, MultiIndex): + constructor = MultiIndex.from_tuples + expected_arr = obj.unique() + if expected_arr.dtype == np.float16: + expected_arr = expected_arr.astype(np.float32) + expected_uniques = constructor(expected_arr) + if ( + isinstance(obj, Index) + and expected_uniques.dtype == bool + and obj.dtype == object + ): + expected_uniques = expected_uniques.astype(object) + + if sort: + expected_uniques = expected_uniques.sort_values() + + # construct an integer ndarray so that + # `expected_uniques.take(expected_codes)` is equal to `obj` + expected_uniques_list = list(expected_uniques) + expected_codes = [expected_uniques_list.index(val) for val in obj] + expected_codes = np.asarray(expected_codes, dtype=np.intp) + + tm.assert_numpy_array_equal(result_codes, expected_codes) + tm.assert_index_equal(result_uniques, expected_uniques, exact=True) + + def test_series_factorize_use_na_sentinel_false(self): + # GH#35667 + values = np.array([1, 2, 1, np.nan]) + ser = Series(values) + codes, uniques = ser.factorize(use_na_sentinel=False) + + expected_codes = np.array([0, 1, 0, 2], dtype=np.intp) + expected_uniques = Index([1.0, 2.0, np.nan]) + + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_index_equal(uniques, expected_uniques) + + def test_basic(self): + items = np.array(["a", "b", "b", "a", "a", "c", "c", "c"], dtype=object) + codes, uniques = algos.factorize(items) + tm.assert_numpy_array_equal(uniques, np.array(["a", "b", "c"], dtype=object)) + + codes, uniques = algos.factorize(items, sort=True) + exp = np.array([0, 1, 1, 0, 0, 2, 2, 2], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array(["a", "b", "c"], dtype=object) + tm.assert_numpy_array_equal(uniques, exp) + + arr = np.arange(5, dtype=np.intp)[::-1] + + codes, uniques = algos.factorize(arr) + exp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([4, 3, 2, 1, 0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + codes, uniques = algos.factorize(arr, sort=True) + exp = np.array([4, 3, 2, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([0, 1, 2, 3, 4], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + arr = np.arange(5.0)[::-1] + + codes, uniques = algos.factorize(arr) + exp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([4.0, 3.0, 2.0, 1.0, 0.0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + codes, uniques = algos.factorize(arr, sort=True) + exp = np.array([4, 3, 2, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([0.0, 1.0, 2.0, 3.0, 4.0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + def test_mixed(self): + # doc example reshaping.rst + x = Series(["A", "A", np.nan, "B", 3.14, np.inf]) + codes, uniques = algos.factorize(x) + + exp = np.array([0, 0, -1, 1, 2, 3], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = Index(["A", "B", 3.14, np.inf]) + tm.assert_index_equal(uniques, exp) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([2, 2, -1, 3, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = Index([3.14, np.inf, "A", "B"]) + tm.assert_index_equal(uniques, exp) + + def test_factorize_datetime64(self): + # M8 + v1 = Timestamp("20130101 09:00:00.00004") + v2 = Timestamp("20130101") + x = Series([v1, v1, v1, v2, v2, v1]) + codes, uniques = algos.factorize(x) + + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = DatetimeIndex([v1, v2]) + tm.assert_index_equal(uniques, exp) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([1, 1, 1, 0, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = DatetimeIndex([v2, v1]) + tm.assert_index_equal(uniques, exp) + + def test_factorize_period(self): + # period + v1 = Period("201302", freq="M") + v2 = Period("201303", freq="M") + x = Series([v1, v1, v1, v2, v2, v1]) + + # periods are not 'sorted' as they are converted back into an index + codes, uniques = algos.factorize(x) + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, PeriodIndex([v1, v2])) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, PeriodIndex([v1, v2])) + + def test_factorize_timedelta(self): + # GH 5986 + v1 = to_timedelta("1 day 1 min") + v2 = to_timedelta("1 day") + x = Series([v1, v2, v1, v1, v2, v2, v1]) + codes, uniques = algos.factorize(x) + exp = np.array([0, 1, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, to_timedelta([v1, v2])) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([1, 0, 1, 1, 0, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, to_timedelta([v2, v1])) + + def test_factorize_nan(self): + # nan should map to na_sentinel, not reverse_indexer[na_sentinel] + # rizer.factorize should not raise an exception if na_sentinel indexes + # outside of reverse_indexer + key = np.array([1, 2, 1, np.nan], dtype="O") + rizer = ht.ObjectFactorizer(len(key)) + for na_sentinel in (-1, 20): + ids = rizer.factorize(key, na_sentinel=na_sentinel) + expected = np.array([0, 1, 0, na_sentinel], dtype=np.intp) + assert len(set(key)) == len(set(expected)) + tm.assert_numpy_array_equal(pd.isna(key), expected == na_sentinel) + tm.assert_numpy_array_equal(ids, expected) + + def test_factorizer_with_mask(self): + # GH#49549 + data = np.array([1, 2, 3, 1, 1, 0], dtype="int64") + mask = np.array([False, False, False, False, False, True]) + rizer = ht.Int64Factorizer(len(data)) + result = rizer.factorize(data, mask=mask) + expected = np.array([0, 1, 2, 0, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + expected_uniques = np.array([1, 2, 3], dtype="int64") + tm.assert_numpy_array_equal(rizer.uniques.to_array(), expected_uniques) + + def test_factorizer_object_with_nan(self): + # GH#49549 + data = np.array([1, 2, 3, 1, np.nan]) + rizer = ht.ObjectFactorizer(len(data)) + result = rizer.factorize(data.astype(object)) + expected = np.array([0, 1, 2, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + expected_uniques = np.array([1, 2, 3], dtype=object) + tm.assert_numpy_array_equal(rizer.uniques.to_array(), expected_uniques) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + [(1, 1), (1, 2), (0, 0), (1, 2), "nonsense"], + [0, 1, 2, 1, 3], + [(1, 1), (1, 2), (0, 0), "nonsense"], + ), + ( + [(1, 1), (1, 2), (0, 0), (1, 2), (1, 2, 3)], + [0, 1, 2, 1, 3], + [(1, 1), (1, 2), (0, 0), (1, 2, 3)], + ), + ([(1, 1), (1, 2), (0, 0), (1, 2)], [0, 1, 2, 1], [(1, 1), (1, 2), (0, 0)]), + ], + ) + def test_factorize_tuple_list(self, data, expected_codes, expected_uniques): + # GH9454 + msg = "factorize with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + codes, uniques = pd.factorize(data) + + tm.assert_numpy_array_equal(codes, np.array(expected_codes, dtype=np.intp)) + + expected_uniques_array = com.asarray_tuplesafe(expected_uniques, dtype=object) + tm.assert_numpy_array_equal(uniques, expected_uniques_array) + + def test_complex_sorting(self): + # gh 12666 - check no segfault + x17 = np.array([complex(i) for i in range(17)], dtype=object) + + msg = "'[<>]' not supported between instances of .*" + with pytest.raises(TypeError, match=msg): + algos.factorize(x17[::-1], sort=True) + + def test_numeric_dtype_factorize(self, any_real_numpy_dtype): + # GH41132 + dtype = any_real_numpy_dtype + data = np.array([1, 2, 2, 1], dtype=dtype) + expected_codes = np.array([0, 1, 1, 0], dtype=np.intp) + expected_uniques = np.array([1, 2], dtype=dtype) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_float64_factorize(self, writable): + data = np.array([1.0, 1e8, 1.0, 1e-8, 1e8, 1.0], dtype=np.float64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0, 2, 1, 0], dtype=np.intp) + expected_uniques = np.array([1.0, 1e8, 1e-8], dtype=np.float64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_uint64_factorize(self, writable): + data = np.array([2**64 - 1, 1, 2**64 - 1], dtype=np.uint64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0], dtype=np.intp) + expected_uniques = np.array([2**64 - 1, 1], dtype=np.uint64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_int64_factorize(self, writable): + data = np.array([2**63 - 1, -(2**63), 2**63 - 1], dtype=np.int64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0], dtype=np.intp) + expected_uniques = np.array([2**63 - 1, -(2**63)], dtype=np.int64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_string_factorize(self, writable): + data = np.array(["a", "c", "a", "b", "c"], dtype=object) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0, 2, 1], dtype=np.intp) + expected_uniques = np.array(["a", "c", "b"], dtype=object) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_object_factorize(self, writable): + data = np.array(["a", "c", None, np.nan, "a", "b", NaT, "c"], dtype=object) + data.setflags(write=writable) + expected_codes = np.array([0, 1, -1, -1, 0, 2, -1, 1], dtype=np.intp) + expected_uniques = np.array(["a", "c", "b"], dtype=object) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_datetime64_factorize(self, writable): + # GH35650 Verify whether read-only datetime64 array can be factorized + data = np.array([np.datetime64("2020-01-01T00:00:00.000")], dtype="M8[ns]") + data.setflags(write=writable) + expected_codes = np.array([0], dtype=np.intp) + expected_uniques = np.array( + ["2020-01-01T00:00:00.000000000"], dtype="datetime64[ns]" + ) + + codes, uniques = pd.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize("sort", [True, False]) + def test_factorize_rangeindex(self, sort): + # increasing -> sort doesn't matter + ri = pd.RangeIndex.from_range(range(10)) + expected = np.arange(10, dtype=np.intp), ri + + result = algos.factorize(ri, sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + result = ri.factorize(sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + @pytest.mark.parametrize("sort", [True, False]) + def test_factorize_rangeindex_decreasing(self, sort): + # decreasing -> sort matters + ri = pd.RangeIndex.from_range(range(10)) + expected = np.arange(10, dtype=np.intp), ri + + ri2 = ri[::-1] + expected = expected[0], ri2 + if sort: + expected = expected[0][::-1], expected[1][::-1] + + result = algos.factorize(ri2, sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + result = ri2.factorize(sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + def test_deprecate_order(self): + # gh 19727 - check warning is raised for deprecated keyword, order. + # Test not valid once order keyword is removed. + data = np.array([2**63, 1, 2**63], dtype=np.uint64) + with pytest.raises(TypeError, match="got an unexpected keyword"): + algos.factorize(data, order=True) + with tm.assert_produces_warning(False): + algos.factorize(data) + + @pytest.mark.parametrize( + "data", + [ + np.array([0, 1, 0], dtype="u8"), + np.array([-(2**63), 1, -(2**63)], dtype="i8"), + np.array(["__nan__", "foo", "__nan__"], dtype="object"), + ], + ) + def test_parametrized_factorize_na_value_default(self, data): + # arrays that include the NA default for that type, but isn't used. + codes, uniques = algos.factorize(data) + expected_uniques = data[[0, 1]] + expected_codes = np.array([0, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize( + "data, na_value", + [ + (np.array([0, 1, 0, 2], dtype="u8"), 0), + (np.array([1, 0, 1, 2], dtype="u8"), 1), + (np.array([-(2**63), 1, -(2**63), 0], dtype="i8"), -(2**63)), + (np.array([1, -(2**63), 1, 0], dtype="i8"), 1), + (np.array(["a", "", "a", "b"], dtype=object), "a"), + (np.array([(), ("a", 1), (), ("a", 2)], dtype=object), ()), + (np.array([("a", 1), (), ("a", 1), ("a", 2)], dtype=object), ("a", 1)), + ], + ) + def test_parametrized_factorize_na_value(self, data, na_value): + codes, uniques = algos.factorize_array(data, na_value=na_value) + expected_uniques = data[[1, 3]] + expected_codes = np.array([-1, 0, -1, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize("sort", [True, False]) + @pytest.mark.parametrize( + "data, uniques", + [ + ( + np.array(["b", "a", None, "b"], dtype=object), + np.array(["b", "a"], dtype=object), + ), + ( + pd.array([2, 1, np.nan, 2], dtype="Int64"), + pd.array([2, 1], dtype="Int64"), + ), + ], + ids=["numpy_array", "extension_array"], + ) + def test_factorize_use_na_sentinel(self, sort, data, uniques): + codes, uniques = algos.factorize(data, sort=sort, use_na_sentinel=True) + if sort: + expected_codes = np.array([1, 0, -1, 1], dtype=np.intp) + expected_uniques = algos.safe_sort(uniques) + else: + expected_codes = np.array([0, 1, -1, 0], dtype=np.intp) + expected_uniques = uniques + tm.assert_numpy_array_equal(codes, expected_codes) + if isinstance(data, np.ndarray): + tm.assert_numpy_array_equal(uniques, expected_uniques) + else: + tm.assert_extension_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + ["a", None, "b", "a"], + np.array([0, 1, 2, 0], dtype=np.dtype("intp")), + np.array(["a", np.nan, "b"], dtype=object), + ), + ( + ["a", np.nan, "b", "a"], + np.array([0, 1, 2, 0], dtype=np.dtype("intp")), + np.array(["a", np.nan, "b"], dtype=object), + ), + ], + ) + def test_object_factorize_use_na_sentinel_false( + self, data, expected_codes, expected_uniques + ): + codes, uniques = algos.factorize( + np.array(data, dtype=object), use_na_sentinel=False + ) + + tm.assert_numpy_array_equal(uniques, expected_uniques, strict_nan=True) + tm.assert_numpy_array_equal(codes, expected_codes, strict_nan=True) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + [1, None, 1, 2], + np.array([0, 1, 0, 2], dtype=np.dtype("intp")), + np.array([1, np.nan, 2], dtype="O"), + ), + ( + [1, np.nan, 1, 2], + np.array([0, 1, 0, 2], dtype=np.dtype("intp")), + np.array([1, np.nan, 2], dtype=np.float64), + ), + ], + ) + def test_int_factorize_use_na_sentinel_false( + self, data, expected_codes, expected_uniques + ): + msg = "factorize with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + codes, uniques = algos.factorize(data, use_na_sentinel=False) + + tm.assert_numpy_array_equal(uniques, expected_uniques, strict_nan=True) + tm.assert_numpy_array_equal(codes, expected_codes, strict_nan=True) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + Index(Categorical(["a", "a", "b"])), + np.array([0, 0, 1], dtype=np.intp), + CategoricalIndex(["a", "b"], categories=["a", "b"], dtype="category"), + ), + ( + Series(Categorical(["a", "a", "b"])), + np.array([0, 0, 1], dtype=np.intp), + CategoricalIndex(["a", "b"], categories=["a", "b"], dtype="category"), + ), + ( + Series(DatetimeIndex(["2017", "2017"], tz="US/Eastern")), + np.array([0, 0], dtype=np.intp), + DatetimeIndex(["2017"], tz="US/Eastern"), + ), + ], + ) + def test_factorize_mixed_values(self, data, expected_codes, expected_uniques): + # GH 19721 + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_index_equal(uniques, expected_uniques) + + def test_factorize_interval_non_nano(self, unit): + # GH#56099 + left = DatetimeIndex(["2016-01-01", np.nan, "2015-10-11"]).as_unit(unit) + right = DatetimeIndex(["2016-01-02", np.nan, "2015-10-15"]).as_unit(unit) + idx = IntervalIndex.from_arrays(left, right) + codes, cats = idx.factorize() + assert cats.dtype == f"interval[datetime64[{unit}], right]" + + ts = Timestamp(0).as_unit(unit) + idx2 = IntervalIndex.from_arrays(left - ts, right - ts) + codes2, cats2 = idx2.factorize() + assert cats2.dtype == f"interval[timedelta64[{unit}], right]" + + idx3 = IntervalIndex.from_arrays( + left.tz_localize("US/Pacific"), right.tz_localize("US/Pacific") + ) + codes3, cats3 = idx3.factorize() + assert cats3.dtype == f"interval[datetime64[{unit}, US/Pacific], right]" + + +class TestUnique: + def test_ints(self): + arr = np.random.default_rng(2).integers(0, 100, size=50) + + result = algos.unique(arr) + assert isinstance(result, np.ndarray) + + def test_objects(self): + arr = np.random.default_rng(2).integers(0, 100, size=50).astype("O") + + result = algos.unique(arr) + assert isinstance(result, np.ndarray) + + def test_object_refcount_bug(self): + lst = np.array(["A", "B", "C", "D", "E"], dtype=object) + for i in range(1000): + len(algos.unique(lst)) + + def test_on_index_object(self): + mindex = MultiIndex.from_arrays( + [np.arange(5).repeat(5), np.tile(np.arange(5), 5)] + ) + expected = mindex.values + expected.sort() + + mindex = mindex.repeat(2) + + result = pd.unique(mindex) + result.sort() + + tm.assert_almost_equal(result, expected) + + def test_dtype_preservation(self, any_numpy_dtype): + # GH 15442 + if any_numpy_dtype in (tm.BYTES_DTYPES + tm.STRING_DTYPES): + data = [1, 2, 2] + uniques = [1, 2] + elif is_integer_dtype(any_numpy_dtype): + data = [1, 2, 2] + uniques = [1, 2] + elif is_float_dtype(any_numpy_dtype): + data = [1, 2, 2] + uniques = [1.0, 2.0] + elif is_complex_dtype(any_numpy_dtype): + data = [complex(1, 0), complex(2, 0), complex(2, 0)] + uniques = [complex(1, 0), complex(2, 0)] + elif is_bool_dtype(any_numpy_dtype): + data = [True, True, False] + uniques = [True, False] + elif is_object_dtype(any_numpy_dtype): + data = ["A", "B", "B"] + uniques = ["A", "B"] + else: + # datetime64[ns]/M8[ns]/timedelta64[ns]/m8[ns] tested elsewhere + data = [1, 2, 2] + uniques = [1, 2] + + result = Series(data, dtype=any_numpy_dtype).unique() + expected = np.array(uniques, dtype=any_numpy_dtype) + + if any_numpy_dtype in tm.STRING_DTYPES: + expected = expected.astype(object) + + if expected.dtype.kind in ["m", "M"]: + # We get TimedeltaArray/DatetimeArray + assert isinstance(result, (DatetimeArray, TimedeltaArray)) + result = np.array(result) + tm.assert_numpy_array_equal(result, expected) + + def test_datetime64_dtype_array_returned(self): + # GH 9431 + expected = np.array( + [ + "2015-01-03T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + ], + dtype="M8[ns]", + ) + + dt_index = to_datetime( + [ + "2015-01-03T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + ] + ) + result = algos.unique(dt_index) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + s = Series(dt_index) + result = algos.unique(s) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + arr = s.values + result = algos.unique(arr) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + def test_datetime_non_ns(self): + a = np.array(["2000", "2000", "2001"], dtype="datetime64[s]") + result = pd.unique(a) + expected = np.array(["2000", "2001"], dtype="datetime64[s]") + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta_non_ns(self): + a = np.array(["2000", "2000", "2001"], dtype="timedelta64[s]") + result = pd.unique(a) + expected = np.array([2000, 2001], dtype="timedelta64[s]") + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta64_dtype_array_returned(self): + # GH 9431 + expected = np.array([31200, 45678, 10000], dtype="m8[ns]") + + td_index = to_timedelta([31200, 45678, 31200, 10000, 45678]) + result = algos.unique(td_index) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + s = Series(td_index) + result = algos.unique(s) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + arr = s.values + result = algos.unique(arr) + tm.assert_numpy_array_equal(result, expected) + assert result.dtype == expected.dtype + + def test_uint64_overflow(self): + s = Series([1, 2, 2**63, 2**63], dtype=np.uint64) + exp = np.array([1, 2, 2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(algos.unique(s), exp) + + def test_nan_in_object_array(self): + duplicated_items = ["a", np.nan, "c", "c"] + result = pd.unique(np.array(duplicated_items, dtype=object)) + expected = np.array(["a", np.nan, "c"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_categorical(self): + # we are expecting to return in the order + # of appearance + expected = Categorical(list("bac")) + + # we are expecting to return in the order + # of the categories + expected_o = Categorical(list("bac"), categories=list("abc"), ordered=True) + + # GH 15939 + c = Categorical(list("baabc")) + result = c.unique() + tm.assert_categorical_equal(result, expected) + + result = algos.unique(c) + tm.assert_categorical_equal(result, expected) + + c = Categorical(list("baabc"), ordered=True) + result = c.unique() + tm.assert_categorical_equal(result, expected_o) + + result = algos.unique(c) + tm.assert_categorical_equal(result, expected_o) + + # Series of categorical dtype + s = Series(Categorical(list("baabc")), name="foo") + result = s.unique() + tm.assert_categorical_equal(result, expected) + + result = pd.unique(s) + tm.assert_categorical_equal(result, expected) + + # CI -> return CI + ci = CategoricalIndex(Categorical(list("baabc"), categories=list("abc"))) + expected = CategoricalIndex(expected) + result = ci.unique() + tm.assert_index_equal(result, expected) + + result = pd.unique(ci) + tm.assert_index_equal(result, expected) + + def test_datetime64tz_aware(self, unit): + # GH 15939 + + dti = Index( + [ + Timestamp("20160101", tz="US/Eastern"), + Timestamp("20160101", tz="US/Eastern"), + ] + ).as_unit(unit) + ser = Series(dti) + + result = ser.unique() + expected = dti[:1]._data + tm.assert_extension_array_equal(result, expected) + + result = dti.unique() + expected = dti[:1] + tm.assert_index_equal(result, expected) + + result = pd.unique(ser) + expected = dti[:1]._data + tm.assert_extension_array_equal(result, expected) + + result = pd.unique(dti) + expected = dti[:1] + tm.assert_index_equal(result, expected) + + def test_order_of_appearance(self): + # 9346 + # light testing of guarantee of order of appearance + # these also are the doc-examples + result = pd.unique(Series([2, 1, 3, 3])) + tm.assert_numpy_array_equal(result, np.array([2, 1, 3], dtype="int64")) + + result = pd.unique(Series([2] + [1] * 5)) + tm.assert_numpy_array_equal(result, np.array([2, 1], dtype="int64")) + + msg = "unique with argument that is not not a Series, Index," + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pd.unique(list("aabc")) + expected = np.array(["a", "b", "c"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + result = pd.unique(Series(Categorical(list("aabc")))) + expected = Categorical(list("abc")) + tm.assert_categorical_equal(result, expected) + + def test_order_of_appearance_dt64(self, unit): + ser = Series([Timestamp("20160101"), Timestamp("20160101")]).dt.as_unit(unit) + result = pd.unique(ser) + expected = np.array(["2016-01-01T00:00:00.000000000"], dtype=f"M8[{unit}]") + tm.assert_numpy_array_equal(result, expected) + + def test_order_of_appearance_dt64tz(self, unit): + dti = DatetimeIndex( + [ + Timestamp("20160101", tz="US/Eastern"), + Timestamp("20160101", tz="US/Eastern"), + ] + ).as_unit(unit) + result = pd.unique(dti) + expected = DatetimeIndex( + ["2016-01-01 00:00:00"], dtype=f"datetime64[{unit}, US/Eastern]", freq=None + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "arg ,expected", + [ + (("1", "1", "2"), np.array(["1", "2"], dtype=object)), + (("foo",), np.array(["foo"], dtype=object)), + ], + ) + def test_tuple_with_strings(self, arg, expected): + # see GH 17108 + msg = "unique with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pd.unique(arg) + tm.assert_numpy_array_equal(result, expected) + + def test_obj_none_preservation(self): + # GH 20866 + arr = np.array(["foo", None], dtype=object) + result = pd.unique(arr) + expected = np.array(["foo", None], dtype=object) + + tm.assert_numpy_array_equal(result, expected, strict_nan=True) + + def test_signed_zero(self): + # GH 21866 + a = np.array([-0.0, 0.0]) + result = pd.unique(a) + expected = np.array([-0.0]) # 0.0 and -0.0 are equivalent + tm.assert_numpy_array_equal(result, expected) + + def test_different_nans(self): + # GH 21866 + # create different nans from bit-patterns: + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + a = np.array([NAN1, NAN2]) # NAN1 and NAN2 are equivalent + result = pd.unique(a) + expected = np.array([np.nan]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("el_type", [np.float64, object]) + def test_first_nan_kept(self, el_type): + # GH 22295 + # create different nans from bit-patterns: + bits_for_nan1 = 0xFFF8000000000001 + bits_for_nan2 = 0x7FF8000000000001 + NAN1 = struct.unpack("d", struct.pack("=Q", bits_for_nan1))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", bits_for_nan2))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + a = np.array([NAN1, NAN2], dtype=el_type) + result = pd.unique(a) + assert result.size == 1 + # use bit patterns to identify which nan was kept: + result_nan_bits = struct.unpack("=Q", struct.pack("d", result[0]))[0] + assert result_nan_bits == bits_for_nan1 + + def test_do_not_mangle_na_values(self, unique_nulls_fixture, unique_nulls_fixture2): + # GH 22295 + if unique_nulls_fixture is unique_nulls_fixture2: + return # skip it, values not unique + a = np.array([unique_nulls_fixture, unique_nulls_fixture2], dtype=object) + result = pd.unique(a) + assert result.size == 2 + assert a[0] is unique_nulls_fixture + assert a[1] is unique_nulls_fixture2 + + def test_unique_masked(self, any_numeric_ea_dtype): + # GH#48019 + ser = Series([1, pd.NA, 2] * 3, dtype=any_numeric_ea_dtype) + result = pd.unique(ser) + expected = pd.array([1, pd.NA, 2], dtype=any_numeric_ea_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_nunique_ints(index_or_series_or_array): + # GH#36327 + values = index_or_series_or_array(np.random.default_rng(2).integers(0, 20, 30)) + result = algos.nunique_ints(values) + expected = len(algos.unique(values)) + assert result == expected + + +class TestIsin: + def test_invalid(self): + msg = ( + r"only list-like objects are allowed to be passed to isin\(\), " + r"you passed a `int`" + ) + with pytest.raises(TypeError, match=msg): + algos.isin(1, 1) + with pytest.raises(TypeError, match=msg): + algos.isin(1, [1]) + with pytest.raises(TypeError, match=msg): + algos.isin([1], 1) + + def test_basic(self): + msg = "isin with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.isin([1, 2], [1]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(np.array([1, 2]), [1]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), [1]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), Series([1])) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), {1}) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.isin(["a", "b"], ["a"]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series(["a", "b"]), Series(["a"])) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series(["a", "b"]), {"a"}) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.isin(["a", "b"], [1]) + expected = np.array([False, False]) + tm.assert_numpy_array_equal(result, expected) + + def test_i8(self): + arr = date_range("20130101", periods=3).values + result = algos.isin(arr, [arr[0]]) + expected = np.array([True, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, arr[0:2]) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, set(arr[0:2])) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + arr = timedelta_range("1 day", periods=3).values + result = algos.isin(arr, [arr[0]]) + expected = np.array([True, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, arr[0:2]) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, set(arr[0:2])) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype1", ["m8[ns]", "M8[ns]", "M8[ns, UTC]", "period[D]"]) + @pytest.mark.parametrize("dtype", ["i8", "f8", "u8"]) + def test_isin_datetimelike_values_numeric_comps(self, dtype, dtype1): + # Anything but object and we get all-False shortcut + + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype1) + + comps = arr.view("i8").astype(dtype) + + result = algos.isin(comps, arr) + expected = np.zeros(comps.shape, dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + def test_large(self): + s = date_range("20000101", periods=2000000, freq="s").values + result = algos.isin(s, s[0:2]) + expected = np.zeros(len(s), dtype=bool) + expected[0] = True + expected[1] = True + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["m8[ns]", "M8[ns]", "M8[ns, UTC]", "period[D]"]) + def test_isin_datetimelike_all_nat(self, dtype): + # GH#56427 + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype) + + arr[0] = NaT + result = algos.isin(arr, [NaT]) + expected = np.array([True, False, False], dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["m8[ns]", "M8[ns]", "M8[ns, UTC]"]) + def test_isin_datetimelike_strings_deprecated(self, dtype): + # GH#53111 + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype) + + vals = [str(x) for x in arr] + msg = "The behavior of 'isin' with dtype=.* is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + res = algos.isin(arr, vals) + assert res.all() + + vals2 = np.array(vals, dtype=str) + with tm.assert_produces_warning(FutureWarning, match=msg): + res2 = algos.isin(arr, vals2) + assert res2.all() + + def test_isin_dt64tz_with_nat(self): + # the all-NaT values used to get inferred to tznaive, which was evaluated + # as non-matching GH#56427 + dti = date_range("2016-01-01", periods=3, tz="UTC") + ser = Series(dti) + ser[0] = NaT + + res = algos.isin(ser._values, [NaT]) + exp = np.array([True, False, False], dtype=bool) + tm.assert_numpy_array_equal(res, exp) + + def test_categorical_from_codes(self): + # GH 16639 + vals = np.array([0, 1, 2, 0]) + cats = ["a", "b", "c"] + Sd = Series(Categorical([1]).from_codes(vals, cats)) + St = Series(Categorical([1]).from_codes(np.array([0, 1]), cats)) + expected = np.array([True, True, False, True]) + result = algos.isin(Sd, St) + tm.assert_numpy_array_equal(expected, result) + + def test_categorical_isin(self): + vals = np.array([0, 1, 2, 0]) + cats = ["a", "b", "c"] + cat = Categorical([1]).from_codes(vals, cats) + other = Categorical([1]).from_codes(np.array([0, 1]), cats) + + expected = np.array([True, True, False, True]) + result = algos.isin(cat, other) + tm.assert_numpy_array_equal(expected, result) + + def test_same_nan_is_in(self): + # GH 22160 + # nan is special, because from " a is b" doesn't follow "a == b" + # at least, isin() should follow python's "np.nan in [nan] == True" + # casting to -> np.float64 -> another float-object somewhere on + # the way could lead jeopardize this behavior + comps = [np.nan] # could be casted to float64 + values = [np.nan] + expected = np.array([True]) + msg = "isin with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.isin(comps, values) + tm.assert_numpy_array_equal(expected, result) + + def test_same_nan_is_in_large(self): + # https://github.com/pandas-dev/pandas/issues/22205 + s = np.tile(1.0, 1_000_001) + s[0] = np.nan + result = algos.isin(s, np.array([np.nan, 1])) + expected = np.ones(len(s), dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + def test_same_nan_is_in_large_series(self): + # https://github.com/pandas-dev/pandas/issues/22205 + s = np.tile(1.0, 1_000_001) + series = Series(s) + s[0] = np.nan + result = series.isin(np.array([np.nan, 1])) + expected = Series(np.ones(len(s), dtype=bool)) + tm.assert_series_equal(result, expected) + + def test_same_object_is_in(self): + # GH 22160 + # there could be special treatment for nans + # the user however could define a custom class + # with similar behavior, then we at least should + # fall back to usual python's behavior: "a in [a] == True" + class LikeNan: + def __eq__(self, other) -> bool: + return False + + def __hash__(self): + return 0 + + a, b = LikeNan(), LikeNan() + + msg = "isin with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + # same object -> True + tm.assert_numpy_array_equal(algos.isin([a], [a]), np.array([True])) + # different objects -> False + tm.assert_numpy_array_equal(algos.isin([a], [b]), np.array([False])) + + def test_different_nans(self): + # GH 22160 + # all nans are handled as equivalent + + comps = [float("nan")] + values = [float("nan")] + assert comps[0] is not values[0] # different nan-objects + + # as list of python-objects: + result = algos.isin(np.array(comps), values) + tm.assert_numpy_array_equal(np.array([True]), result) + + # as object-array: + result = algos.isin( + np.asarray(comps, dtype=object), np.asarray(values, dtype=object) + ) + tm.assert_numpy_array_equal(np.array([True]), result) + + # as float64-array: + result = algos.isin( + np.asarray(comps, dtype=np.float64), np.asarray(values, dtype=np.float64) + ) + tm.assert_numpy_array_equal(np.array([True]), result) + + def test_no_cast(self): + # GH 22160 + # ensure 42 is not casted to a string + comps = ["ss", 42] + values = ["42"] + expected = np.array([False, False]) + msg = "isin with argument that is not not a Series, Index" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.isin(comps, values) + tm.assert_numpy_array_equal(expected, result) + + @pytest.mark.parametrize("empty", [[], Series(dtype=object), np.array([])]) + def test_empty(self, empty): + # see gh-16991 + vals = Index(["a", "b"]) + expected = np.array([False, False]) + + result = algos.isin(vals, empty) + tm.assert_numpy_array_equal(expected, result) + + def test_different_nan_objects(self): + # GH 22119 + comps = np.array(["nan", np.nan * 1j, float("nan")], dtype=object) + vals = np.array([float("nan")], dtype=object) + expected = np.array([False, False, True]) + result = algos.isin(comps, vals) + tm.assert_numpy_array_equal(expected, result) + + def test_different_nans_as_float64(self): + # GH 21866 + # create different nans from bit-patterns, + # these nans will land in different buckets in the hash-table + # if no special care is taken + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + + # check that NAN1 and NAN2 are equivalent: + arr = np.array([NAN1, NAN2], dtype=np.float64) + lookup1 = np.array([NAN1], dtype=np.float64) + result = algos.isin(arr, lookup1) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + lookup2 = np.array([NAN2], dtype=np.float64) + result = algos.isin(arr, lookup2) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + def test_isin_int_df_string_search(self): + """Comparing df with int`s (1,2) with a string at isin() ("1") + -> should not match values because int 1 is not equal str 1""" + df = DataFrame({"values": [1, 2]}) + result = df.isin(["1"]) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_nan_df_string_search(self): + """Comparing df with nan value (np.nan,2) with a string at isin() ("NaN") + -> should not match values because np.nan is not equal str NaN""" + df = DataFrame({"values": [np.nan, 2]}) + result = df.isin(np.array(["NaN"], dtype=object)) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_float_df_string_search(self): + """Comparing df with floats (1.4245,2.32441) with a string at isin() ("1.4245") + -> should not match values because float 1.4245 is not equal str 1.4245""" + df = DataFrame({"values": [1.4245, 2.32441]}) + result = df.isin(np.array(["1.4245"], dtype=object)) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_unsigned_dtype(self): + # GH#46485 + ser = Series([1378774140726870442], dtype=np.uint64) + result = ser.isin([1378774140726870528]) + expected = Series(False) + tm.assert_series_equal(result, expected) + + +class TestValueCounts: + def test_value_counts(self): + arr = np.random.default_rng(1234).standard_normal(4) + factor = cut(arr, 4) + + # assert isinstance(factor, n) + msg = "pandas.value_counts is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.value_counts(factor) + breaks = [-1.606, -1.018, -0.431, 0.155, 0.741] + index = IntervalIndex.from_breaks(breaks).astype(CategoricalDtype(ordered=True)) + expected = Series([1, 0, 2, 1], index=index, name="count") + tm.assert_series_equal(result.sort_index(), expected.sort_index()) + + def test_value_counts_bins(self): + s = [1, 2, 3, 4] + msg = "pandas.value_counts is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.value_counts(s, bins=1) + expected = Series( + [4], index=IntervalIndex.from_tuples([(0.996, 4.0)]), name="count" + ) + tm.assert_series_equal(result, expected) + + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.value_counts(s, bins=2, sort=False) + expected = Series( + [2, 2], + index=IntervalIndex.from_tuples([(0.996, 2.5), (2.5, 4.0)]), + name="count", + ) + tm.assert_series_equal(result, expected) + + def test_value_counts_dtypes(self): + msg2 = "pandas.value_counts is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg2): + result = algos.value_counts(np.array([1, 1.0])) + assert len(result) == 1 + + with tm.assert_produces_warning(FutureWarning, match=msg2): + result = algos.value_counts(np.array([1, 1.0]), bins=1) + assert len(result) == 1 + + with tm.assert_produces_warning(FutureWarning, match=msg2): + result = algos.value_counts(Series([1, 1.0, "1"])) # object + assert len(result) == 2 + + msg = "bins argument only works with numeric data" + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(FutureWarning, match=msg2): + algos.value_counts(np.array(["1", 1], dtype=object), bins=1) + + def test_value_counts_nat(self): + td = Series([np.timedelta64(10000), NaT], dtype="timedelta64[ns]") + dt = to_datetime(["NaT", "2014-01-01"]) + + msg = "pandas.value_counts is deprecated" + + for ser in [td, dt]: + with tm.assert_produces_warning(FutureWarning, match=msg): + vc = algos.value_counts(ser) + vc_with_na = algos.value_counts(ser, dropna=False) + assert len(vc) == 1 + assert len(vc_with_na) == 2 + + exp_dt = Series({Timestamp("2014-01-01 00:00:00"): 1}, name="count") + with tm.assert_produces_warning(FutureWarning, match=msg): + result_dt = algos.value_counts(dt) + tm.assert_series_equal(result_dt, exp_dt) + + exp_td = Series({np.timedelta64(10000): 1}, name="count") + with tm.assert_produces_warning(FutureWarning, match=msg): + result_td = algos.value_counts(td) + tm.assert_series_equal(result_td, exp_td) + + @pytest.mark.parametrize("dtype", [object, "M8[us]"]) + def test_value_counts_datetime_outofbounds(self, dtype): + # GH 13663 + ser = Series( + [ + datetime(3000, 1, 1), + datetime(5000, 1, 1), + datetime(5000, 1, 1), + datetime(6000, 1, 1), + datetime(3000, 1, 1), + datetime(3000, 1, 1), + ], + dtype=dtype, + ) + res = ser.value_counts() + + exp_index = Index( + [datetime(3000, 1, 1), datetime(5000, 1, 1), datetime(6000, 1, 1)], + dtype=dtype, + ) + exp = Series([3, 2, 1], index=exp_index, name="count") + tm.assert_series_equal(res, exp) + + def test_categorical(self): + s = Series(Categorical(list("aaabbc"))) + result = s.value_counts() + expected = Series( + [3, 2, 1], index=CategoricalIndex(["a", "b", "c"]), name="count" + ) + + tm.assert_series_equal(result, expected, check_index_type=True) + + # preserve order? + s = s.cat.as_ordered() + result = s.value_counts() + expected.index = expected.index.as_ordered() + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_categorical_nans(self): + s = Series(Categorical(list("aaaaabbbcc"))) # 4,3,2,1 (nan) + s.iloc[1] = np.nan + result = s.value_counts() + expected = Series( + [4, 3, 2], + index=CategoricalIndex(["a", "b", "c"], categories=["a", "b", "c"]), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + result = s.value_counts(dropna=False) + expected = Series( + [4, 3, 2, 1], index=CategoricalIndex(["a", "b", "c", np.nan]), name="count" + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + # out of order + s = Series( + Categorical(list("aaaaabbbcc"), ordered=True, categories=["b", "a", "c"]) + ) + s.iloc[1] = np.nan + result = s.value_counts() + expected = Series( + [4, 3, 2], + index=CategoricalIndex( + ["a", "b", "c"], + categories=["b", "a", "c"], + ordered=True, + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + result = s.value_counts(dropna=False) + expected = Series( + [4, 3, 2, 1], + index=CategoricalIndex( + ["a", "b", "c", np.nan], categories=["b", "a", "c"], ordered=True + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_categorical_zeroes(self): + # keep the `d` category with 0 + s = Series(Categorical(list("bbbaac"), categories=list("abcd"), ordered=True)) + result = s.value_counts() + expected = Series( + [3, 2, 1, 0], + index=Categorical( + ["b", "a", "c", "d"], categories=list("abcd"), ordered=True + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_value_counts_dropna(self): + # https://github.com/pandas-dev/pandas/issues/9443#issuecomment-73719328 + + tm.assert_series_equal( + Series([True, True, False]).value_counts(dropna=True), + Series([2, 1], index=[True, False], name="count"), + ) + tm.assert_series_equal( + Series([True, True, False]).value_counts(dropna=False), + Series([2, 1], index=[True, False], name="count"), + ) + + tm.assert_series_equal( + Series([True] * 3 + [False] * 2 + [None] * 5).value_counts(dropna=True), + Series([3, 2], index=Index([True, False], dtype=object), name="count"), + ) + tm.assert_series_equal( + Series([True] * 5 + [False] * 3 + [None] * 2).value_counts(dropna=False), + Series([5, 3, 2], index=[True, False, None], name="count"), + ) + tm.assert_series_equal( + Series([10.3, 5.0, 5.0]).value_counts(dropna=True), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + tm.assert_series_equal( + Series([10.3, 5.0, 5.0]).value_counts(dropna=False), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + + tm.assert_series_equal( + Series([10.3, 5.0, 5.0, None]).value_counts(dropna=True), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + + result = Series([10.3, 10.3, 5.0, 5.0, 5.0, None]).value_counts(dropna=False) + expected = Series([3, 2, 1], index=[5.0, 10.3, None], name="count") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dtype", (np.float64, object, "M8[ns]")) + def test_value_counts_normalized(self, dtype): + # GH12558 + s = Series([1] * 2 + [2] * 3 + [np.nan] * 5) + s_typed = s.astype(dtype) + result = s_typed.value_counts(normalize=True, dropna=False) + expected = Series( + [0.5, 0.3, 0.2], + index=Series([np.nan, 2.0, 1.0], dtype=dtype), + name="proportion", + ) + tm.assert_series_equal(result, expected) + + result = s_typed.value_counts(normalize=True, dropna=True) + expected = Series( + [0.6, 0.4], index=Series([2.0, 1.0], dtype=dtype), name="proportion" + ) + tm.assert_series_equal(result, expected) + + def test_value_counts_uint64(self): + arr = np.array([2**63], dtype=np.uint64) + expected = Series([1], index=[2**63], name="count") + msg = "pandas.value_counts is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.value_counts(arr) + + tm.assert_series_equal(result, expected) + + arr = np.array([-1, 2**63], dtype=object) + expected = Series([1, 1], index=[-1, 2**63], name="count") + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.value_counts(arr) + + tm.assert_series_equal(result, expected) + + def test_value_counts_series(self): + # GH#54857 + values = np.array([3, 1, 2, 3, 4, np.nan]) + result = Series(values).value_counts(bins=3) + expected = Series( + [2, 2, 1], + index=IntervalIndex.from_tuples( + [(0.996, 2.0), (2.0, 3.0), (3.0, 4.0)], dtype="interval[float64, right]" + ), + name="count", + ) + tm.assert_series_equal(result, expected) + + +class TestDuplicated: + def test_duplicated_with_nas(self): + keys = np.array([0, 1, np.nan, 0, 2, np.nan], dtype=object) + + result = algos.duplicated(keys) + expected = np.array([False, False, False, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="first") + expected = np.array([False, False, False, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="last") + expected = np.array([True, False, True, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep=False) + expected = np.array([True, False, True, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + keys = np.empty(8, dtype=object) + for i, t in enumerate( + zip([0, 0, np.nan, np.nan] * 2, [0, np.nan, 0, np.nan] * 2) + ): + keys[i] = t + + result = algos.duplicated(keys) + falses = [False] * 4 + trues = [True] * 4 + expected = np.array(falses + trues) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="last") + expected = np.array(trues + falses) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep=False) + expected = np.array(trues + trues) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "case", + [ + np.array([1, 2, 1, 5, 3, 2, 4, 1, 5, 6]), + np.array([1.1, 2.2, 1.1, np.nan, 3.3, 2.2, 4.4, 1.1, np.nan, 6.6]), + np.array( + [ + 1 + 1j, + 2 + 2j, + 1 + 1j, + 5 + 5j, + 3 + 3j, + 2 + 2j, + 4 + 4j, + 1 + 1j, + 5 + 5j, + 6 + 6j, + ] + ), + np.array(["a", "b", "a", "e", "c", "b", "d", "a", "e", "f"], dtype=object), + np.array( + [1, 2**63, 1, 3**5, 10, 2**63, 39, 1, 3**5, 7], dtype=np.uint64 + ), + ], + ) + def test_numeric_object_likes(self, case): + exp_first = np.array( + [False, False, True, False, False, True, False, True, True, False] + ) + exp_last = np.array( + [True, True, True, True, False, False, False, False, False, False] + ) + exp_false = exp_first | exp_last + + res_first = algos.duplicated(case, keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = algos.duplicated(case, keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = algos.duplicated(case, keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # index + for idx in [Index(case), Index(case, dtype="category")]: + res_first = idx.duplicated(keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = idx.duplicated(keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = idx.duplicated(keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # series + for s in [Series(case), Series(case, dtype="category")]: + res_first = s.duplicated(keep="first") + tm.assert_series_equal(res_first, Series(exp_first)) + + res_last = s.duplicated(keep="last") + tm.assert_series_equal(res_last, Series(exp_last)) + + res_false = s.duplicated(keep=False) + tm.assert_series_equal(res_false, Series(exp_false)) + + def test_datetime_likes(self): + dt = [ + "2011-01-01", + "2011-01-02", + "2011-01-01", + "NaT", + "2011-01-03", + "2011-01-02", + "2011-01-04", + "2011-01-01", + "NaT", + "2011-01-06", + ] + td = [ + "1 days", + "2 days", + "1 days", + "NaT", + "3 days", + "2 days", + "4 days", + "1 days", + "NaT", + "6 days", + ] + + cases = [ + np.array([Timestamp(d) for d in dt]), + np.array([Timestamp(d, tz="US/Eastern") for d in dt]), + np.array([Period(d, freq="D") for d in dt]), + np.array([np.datetime64(d) for d in dt]), + np.array([Timedelta(d) for d in td]), + ] + + exp_first = np.array( + [False, False, True, False, False, True, False, True, True, False] + ) + exp_last = np.array( + [True, True, True, True, False, False, False, False, False, False] + ) + exp_false = exp_first | exp_last + + for case in cases: + res_first = algos.duplicated(case, keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = algos.duplicated(case, keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = algos.duplicated(case, keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # index + for idx in [ + Index(case), + Index(case, dtype="category"), + Index(case, dtype=object), + ]: + res_first = idx.duplicated(keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = idx.duplicated(keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = idx.duplicated(keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # series + for s in [ + Series(case), + Series(case, dtype="category"), + Series(case, dtype=object), + ]: + res_first = s.duplicated(keep="first") + tm.assert_series_equal(res_first, Series(exp_first)) + + res_last = s.duplicated(keep="last") + tm.assert_series_equal(res_last, Series(exp_last)) + + res_false = s.duplicated(keep=False) + tm.assert_series_equal(res_false, Series(exp_false)) + + @pytest.mark.parametrize("case", [Index([1, 2, 3]), pd.RangeIndex(0, 3)]) + def test_unique_index(self, case): + assert case.is_unique is True + tm.assert_numpy_array_equal(case.duplicated(), np.array([False, False, False])) + + @pytest.mark.parametrize( + "arr, uniques", + [ + ( + [(0, 0), (0, 1), (1, 0), (1, 1), (0, 0), (0, 1), (1, 0), (1, 1)], + [(0, 0), (0, 1), (1, 0), (1, 1)], + ), + ( + [("b", "c"), ("a", "b"), ("a", "b"), ("b", "c")], + [("b", "c"), ("a", "b")], + ), + ([("a", 1), ("b", 2), ("a", 3), ("a", 1)], [("a", 1), ("b", 2), ("a", 3)]), + ], + ) + def test_unique_tuples(self, arr, uniques): + # https://github.com/pandas-dev/pandas/issues/16519 + expected = np.empty(len(uniques), dtype=object) + expected[:] = uniques + + msg = "unique with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pd.unique(arr) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "array,expected", + [ + ( + [1 + 1j, 0, 1, 1j, 1 + 2j, 1 + 2j], + # Should return a complex dtype in the future + np.array([(1 + 1j), 0j, (1 + 0j), 1j, (1 + 2j)], dtype=object), + ) + ], + ) + def test_unique_complex_numbers(self, array, expected): + # GH 17927 + msg = "unique with argument that is not not a Series" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = pd.unique(array) + tm.assert_numpy_array_equal(result, expected) + + +class TestHashTable: + @pytest.mark.parametrize( + "htable, data", + [ + (ht.PyObjectHashTable, [f"foo_{i}" for i in range(1000)]), + (ht.StringHashTable, [f"foo_{i}" for i in range(1000)]), + (ht.Float64HashTable, np.arange(1000, dtype=np.float64)), + (ht.Int64HashTable, np.arange(1000, dtype=np.int64)), + (ht.UInt64HashTable, np.arange(1000, dtype=np.uint64)), + ], + ) + def test_hashtable_unique(self, htable, data, writable): + # output of maker has guaranteed unique elements + s = Series(data) + if htable == ht.Float64HashTable: + # add NaN for float column + s.loc[500] = np.nan + elif htable == ht.PyObjectHashTable: + # use different NaN types for object column + s.loc[500:502] = [np.nan, None, NaT] + + # create duplicated selection + s_duplicated = s.sample(frac=3, replace=True).reset_index(drop=True) + s_duplicated.values.setflags(write=writable) + + # drop_duplicates has own cython code (hash_table_func_helper.pxi) + # and is tested separately; keeps first occurrence like ht.unique() + expected_unique = s_duplicated.drop_duplicates(keep="first").values + result_unique = htable().unique(s_duplicated.values) + tm.assert_numpy_array_equal(result_unique, expected_unique) + + # test return_inverse=True + # reconstruction can only succeed if the inverse is correct + result_unique, result_inverse = htable().unique( + s_duplicated.values, return_inverse=True + ) + tm.assert_numpy_array_equal(result_unique, expected_unique) + reconstr = result_unique[result_inverse] + tm.assert_numpy_array_equal(reconstr, s_duplicated.values) + + @pytest.mark.parametrize( + "htable, data", + [ + (ht.PyObjectHashTable, [f"foo_{i}" for i in range(1000)]), + (ht.StringHashTable, [f"foo_{i}" for i in range(1000)]), + (ht.Float64HashTable, np.arange(1000, dtype=np.float64)), + (ht.Int64HashTable, np.arange(1000, dtype=np.int64)), + (ht.UInt64HashTable, np.arange(1000, dtype=np.uint64)), + ], + ) + def test_hashtable_factorize(self, htable, writable, data): + # output of maker has guaranteed unique elements + s = Series(data) + if htable == ht.Float64HashTable: + # add NaN for float column + s.loc[500] = np.nan + elif htable == ht.PyObjectHashTable: + # use different NaN types for object column + s.loc[500:502] = [np.nan, None, NaT] + + # create duplicated selection + s_duplicated = s.sample(frac=3, replace=True).reset_index(drop=True) + s_duplicated.values.setflags(write=writable) + na_mask = s_duplicated.isna().values + + result_unique, result_inverse = htable().factorize(s_duplicated.values) + + # drop_duplicates has own cython code (hash_table_func_helper.pxi) + # and is tested separately; keeps first occurrence like ht.factorize() + # since factorize removes all NaNs, we do the same here + expected_unique = s_duplicated.dropna().drop_duplicates().values + tm.assert_numpy_array_equal(result_unique, expected_unique) + + # reconstruction can only succeed if the inverse is correct. Since + # factorize removes the NaNs, those have to be excluded here as well + result_reconstruct = result_unique[result_inverse[~na_mask]] + expected_reconstruct = s_duplicated.dropna().values + tm.assert_numpy_array_equal(result_reconstruct, expected_reconstruct) + + +class TestRank: + @pytest.mark.parametrize( + "arr", + [ + [np.nan, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 3, np.nan], + [4.0, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 4.0, np.nan], + ], + ) + def test_scipy_compat(self, arr): + sp_stats = pytest.importorskip("scipy.stats") + + arr = np.array(arr) + + mask = ~np.isfinite(arr) + arr = arr.copy() + result = libalgos.rank_1d(arr) + arr[mask] = np.inf + exp = sp_stats.rankdata(arr) + exp[mask] = np.nan + tm.assert_almost_equal(result, exp) + + @pytest.mark.parametrize("dtype", np.typecodes["AllInteger"]) + def test_basic(self, writable, dtype): + exp = np.array([1, 2], dtype=np.float64) + + data = np.array([1, 100], dtype=dtype) + data.setflags(write=writable) + ser = Series(data) + result = algos.rank(ser) + tm.assert_numpy_array_equal(result, exp) + + @pytest.mark.parametrize("dtype", [np.float64, np.uint64]) + def test_uint64_overflow(self, dtype): + exp = np.array([1, 2], dtype=np.float64) + + s = Series([1, 2**63], dtype=dtype) + tm.assert_numpy_array_equal(algos.rank(s), exp) + + def test_too_many_ndims(self): + arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) + msg = "Array with ndim > 2 are not supported" + + with pytest.raises(TypeError, match=msg): + algos.rank(arr) + + @pytest.mark.single_cpu + def test_pct_max_many_rows(self): + # GH 18271 + values = np.arange(2**24 + 1) + result = algos.rank(values, pct=True).max() + assert result == 1 + + values = np.arange(2**25 + 2).reshape(2**24 + 1, 2) + result = algos.rank(values, pct=True).max() + assert result == 1 + + +class TestMode: + def test_no_mode(self): + exp = Series([], dtype=np.float64, index=Index([], dtype=int)) + tm.assert_numpy_array_equal(algos.mode(np.array([])), exp.values) + + @pytest.mark.parametrize("dt", np.typecodes["AllInteger"] + np.typecodes["Float"]) + def test_mode_single(self, dt): + # GH 15714 + exp_single = [1] + data_single = [1] + + exp_multi = [1] + data_multi = [1, 1] + + ser = Series(data_single, dtype=dt) + exp = Series(exp_single, dtype=dt) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + ser = Series(data_multi, dtype=dt) + exp = Series(exp_multi, dtype=dt) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_mode_obj_int(self): + exp = Series([1], dtype=int) + tm.assert_numpy_array_equal(algos.mode(exp.values), exp.values) + + exp = Series(["a", "b", "c"], dtype=object) + tm.assert_numpy_array_equal(algos.mode(exp.values), exp.values) + + @pytest.mark.parametrize("dt", np.typecodes["AllInteger"] + np.typecodes["Float"]) + def test_number_mode(self, dt): + exp_single = [1] + data_single = [1] * 5 + [2] * 3 + + exp_multi = [1, 3] + data_multi = [1] * 5 + [2] * 3 + [3] * 5 + + ser = Series(data_single, dtype=dt) + exp = Series(exp_single, dtype=dt) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + ser = Series(data_multi, dtype=dt) + exp = Series(exp_multi, dtype=dt) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_strobj_mode(self): + exp = ["b"] + data = ["a"] * 2 + ["b"] * 3 + + ser = Series(data, dtype="c") + exp = Series(exp, dtype="c") + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + @pytest.mark.parametrize("dt", [str, object]) + def test_strobj_multi_char(self, dt): + exp = ["bar"] + data = ["foo"] * 2 + ["bar"] * 3 + + ser = Series(data, dtype=dt) + exp = Series(exp, dtype=dt) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_datelike_mode(self): + exp = Series(["1900-05-03", "2011-01-03", "2013-01-02"], dtype="M8[ns]") + ser = Series(["2011-01-03", "2013-01-02", "1900-05-03"], dtype="M8[ns]") + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series(["2011-01-03", "2013-01-02"], dtype="M8[ns]") + ser = Series( + ["2011-01-03", "2013-01-02", "1900-05-03", "2011-01-03", "2013-01-02"], + dtype="M8[ns]", + ) + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + def test_timedelta_mode(self): + exp = Series(["-1 days", "0 days", "1 days"], dtype="timedelta64[ns]") + ser = Series(["1 days", "-1 days", "0 days"], dtype="timedelta64[ns]") + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series(["2 min", "1 day"], dtype="timedelta64[ns]") + ser = Series( + ["1 day", "1 day", "-1 day", "-1 day 2 min", "2 min", "2 min"], + dtype="timedelta64[ns]", + ) + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + def test_mixed_dtype(self): + exp = Series(["foo"], dtype=object) + ser = Series([1, "foo", "foo"]) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_uint64_overflow(self): + exp = Series([2**63], dtype=np.uint64) + ser = Series([1, 2**63, 2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series([1, 2**63], dtype=np.uint64) + ser = Series([1, 2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(algos.mode(ser.values), exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_categorical(self): + c = Categorical([1, 2]) + exp = c + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + c = Categorical([1, "a", "a"]) + exp = Categorical(["a"], categories=[1, "a"]) + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + c = Categorical([1, 1, 2, 3, 3]) + exp = Categorical([1, 3], categories=[1, 2, 3]) + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + def test_index(self): + idx = Index([1, 2, 3]) + exp = Series([1, 2, 3], dtype=np.int64) + tm.assert_numpy_array_equal(algos.mode(idx), exp.values) + + idx = Index([1, "a", "a"]) + exp = Series(["a"], dtype=object) + tm.assert_numpy_array_equal(algos.mode(idx), exp.values) + + idx = Index([1, 1, 2, 3, 3]) + exp = Series([1, 3], dtype=np.int64) + tm.assert_numpy_array_equal(algos.mode(idx), exp.values) + + idx = Index( + ["1 day", "1 day", "-1 day", "-1 day 2 min", "2 min", "2 min"], + dtype="timedelta64[ns]", + ) + with pytest.raises(AttributeError, match="TimedeltaIndex"): + # algos.mode expects Arraylike, does *not* unwrap TimedeltaIndex + algos.mode(idx) + + def test_ser_mode_with_name(self): + # GH 46737 + ser = Series([1, 1, 3], name="foo") + result = ser.mode() + expected = Series([1], name="foo") + tm.assert_series_equal(result, expected) + + +class TestDiff: + @pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) + def test_diff_datetimelike_nat(self, dtype): + # NaT - NaT is NaT, not 0 + arr = np.arange(12).astype(np.int64).view(dtype).reshape(3, 4) + arr[:, 2] = arr.dtype.type("NaT", "ns") + result = algos.diff(arr, 1, axis=0) + + expected = np.ones(arr.shape, dtype="timedelta64[ns]") * 4 + expected[:, 2] = np.timedelta64("NaT", "ns") + expected[0, :] = np.timedelta64("NaT", "ns") + + tm.assert_numpy_array_equal(result, expected) + + result = algos.diff(arr.T, 1, axis=1) + tm.assert_numpy_array_equal(result, expected.T) + + def test_diff_ea_axis(self): + dta = date_range("2016-01-01", periods=3, tz="US/Pacific")._data + + msg = "cannot diff DatetimeArray on axis=1" + with pytest.raises(ValueError, match=msg): + algos.diff(dta, 1, axis=1) + + @pytest.mark.parametrize("dtype", ["int8", "int16"]) + def test_diff_low_precision_int(self, dtype): + arr = np.array([0, 1, 1, 0, 0], dtype=dtype) + result = algos.diff(arr, 1) + expected = np.array([np.nan, 1, 0, -1, 0], dtype="float32") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("op", [np.array, pd.array]) +def test_union_with_duplicates(op): + # GH#36289 + lvals = op([3, 1, 3, 4]) + rvals = op([2, 3, 1, 1]) + expected = op([3, 3, 1, 1, 4, 2]) + if isinstance(expected, np.ndarray): + result = algos.union_with_duplicates(lvals, rvals) + tm.assert_numpy_array_equal(result, expected) + else: + result = algos.union_with_duplicates(lvals, rvals) + tm.assert_extension_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a1c961c8cb6e5b1014f6baa193d4593d85d981 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_common.py @@ -0,0 +1,267 @@ +import collections +from functools import partial +import string +import subprocess +import sys +import textwrap + +import numpy as np +import pytest + +import pandas as pd +from pandas import Series +import pandas._testing as tm +from pandas.core import ops +import pandas.core.common as com +from pandas.util.version import Version + + +def test_get_callable_name(): + getname = com.get_callable_name + + def fn(x): + return x + + lambda_ = lambda x: x + part1 = partial(fn) + part2 = partial(part1) + + class somecall: + def __call__(self): + # This shouldn't actually get called below; somecall.__init__ + # should. + raise NotImplementedError + + assert getname(fn) == "fn" + assert getname(lambda_) + assert getname(part1) == "fn" + assert getname(part2) == "fn" + assert getname(somecall()) == "somecall" + assert getname(1) is None + + +def test_any_none(): + assert com.any_none(1, 2, 3, None) + assert not com.any_none(1, 2, 3, 4) + + +def test_all_not_none(): + assert com.all_not_none(1, 2, 3, 4) + assert not com.all_not_none(1, 2, 3, None) + assert not com.all_not_none(None, None, None, None) + + +def test_random_state(): + # Check with seed + state = com.random_state(5) + assert state.uniform() == np.random.RandomState(5).uniform() + + # Check with random state object + state2 = np.random.RandomState(10) + assert com.random_state(state2).uniform() == np.random.RandomState(10).uniform() + + # check with no arg random state + assert com.random_state() is np.random + + # check array-like + # GH32503 + state_arr_like = np.random.default_rng(None).integers( + 0, 2**31, size=624, dtype="uint32" + ) + assert ( + com.random_state(state_arr_like).uniform() + == np.random.RandomState(state_arr_like).uniform() + ) + + # Check BitGenerators + # GH32503 + assert ( + com.random_state(np.random.MT19937(3)).uniform() + == np.random.RandomState(np.random.MT19937(3)).uniform() + ) + assert ( + com.random_state(np.random.PCG64(11)).uniform() + == np.random.RandomState(np.random.PCG64(11)).uniform() + ) + + # Error for floats or strings + msg = ( + "random_state must be an integer, array-like, a BitGenerator, Generator, " + "a numpy RandomState, or None" + ) + with pytest.raises(ValueError, match=msg): + com.random_state("test") + + with pytest.raises(ValueError, match=msg): + com.random_state(5.5) + + +@pytest.mark.parametrize( + "left, right, expected", + [ + (Series([1], name="x"), Series([2], name="x"), "x"), + (Series([1], name="x"), Series([2], name="y"), None), + (Series([1]), Series([2], name="x"), None), + (Series([1], name="x"), Series([2]), None), + (Series([1], name="x"), [2], "x"), + ([1], Series([2], name="y"), "y"), + # matching NAs + (Series([1], name=np.nan), pd.Index([], name=np.nan), np.nan), + (Series([1], name=np.nan), pd.Index([], name=pd.NaT), None), + (Series([1], name=pd.NA), pd.Index([], name=pd.NA), pd.NA), + # tuple name GH#39757 + ( + Series([1], name=np.int64(1)), + pd.Index([], name=(np.int64(1), np.int64(2))), + None, + ), + ( + Series([1], name=(np.int64(1), np.int64(2))), + pd.Index([], name=(np.int64(1), np.int64(2))), + (np.int64(1), np.int64(2)), + ), + pytest.param( + Series([1], name=(np.float64("nan"), np.int64(2))), + pd.Index([], name=(np.float64("nan"), np.int64(2))), + (np.float64("nan"), np.int64(2)), + marks=pytest.mark.xfail( + reason="Not checking for matching NAs inside tuples." + ), + ), + ], +) +def test_maybe_match_name(left, right, expected): + res = ops.common._maybe_match_name(left, right) + assert res is expected or res == expected + + +def test_standardize_mapping(): + # No uninitialized defaultdicts + msg = r"to_dict\(\) only accepts initialized defaultdicts" + with pytest.raises(TypeError, match=msg): + com.standardize_mapping(collections.defaultdict) + + # No non-mapping subtypes, instance + msg = "unsupported type: " + with pytest.raises(TypeError, match=msg): + com.standardize_mapping([]) + + # No non-mapping subtypes, class + with pytest.raises(TypeError, match=msg): + com.standardize_mapping(list) + + fill = {"bad": "data"} + assert com.standardize_mapping(fill) == dict + + # Convert instance to type + assert com.standardize_mapping({}) == dict + + dd = collections.defaultdict(list) + assert isinstance(com.standardize_mapping(dd), partial) + + +def test_git_version(): + # GH 21295 + git_version = pd.__git_version__ + assert len(git_version) == 40 + assert all(c in string.hexdigits for c in git_version) + + +def test_version_tag(): + version = Version(pd.__version__) + try: + version > Version("0.0.1") + except TypeError: + raise ValueError( + "No git tags exist, please sync tags between upstream and your repo" + ) + + +@pytest.mark.parametrize( + "obj", [(obj,) for obj in pd.__dict__.values() if callable(obj)] +) +def test_serializable(obj): + # GH 35611 + unpickled = tm.round_trip_pickle(obj) + assert type(obj) == type(unpickled) + + +class TestIsBoolIndexer: + def test_non_bool_array_with_na(self): + # in particular, this should not raise + arr = np.array(["A", "B", np.nan], dtype=object) + assert not com.is_bool_indexer(arr) + + def test_list_subclass(self): + # GH#42433 + + class MyList(list): + pass + + val = MyList(["a"]) + + assert not com.is_bool_indexer(val) + + val = MyList([True]) + assert com.is_bool_indexer(val) + + def test_frozenlist(self): + # GH#42461 + data = {"col1": [1, 2], "col2": [3, 4]} + df = pd.DataFrame(data=data) + + frozen = df.index.names[1:] + assert not com.is_bool_indexer(frozen) + + result = df[frozen] + expected = df[[]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("with_exception", [True, False]) +def test_temp_setattr(with_exception): + # GH#45954 + ser = Series(dtype=object) + ser.name = "first" + # Raise a ValueError in either case to satisfy pytest.raises + match = "Inside exception raised" if with_exception else "Outside exception raised" + with pytest.raises(ValueError, match=match): + with com.temp_setattr(ser, "name", "second"): + assert ser.name == "second" + if with_exception: + raise ValueError("Inside exception raised") + raise ValueError("Outside exception raised") + assert ser.name == "first" + + +@pytest.mark.single_cpu +def test_str_size(): + # GH#21758 + a = "a" + expected = sys.getsizeof(a) + pyexe = sys.executable.replace("\\", "/") + call = [ + pyexe, + "-c", + "a='a';import sys;sys.getsizeof(a);import pandas;print(sys.getsizeof(a));", + ] + result = subprocess.check_output(call).decode()[-4:-1].strip("\n") + assert int(result) == int(expected) + + +@pytest.mark.single_cpu +def test_bz2_missing_import(): + # Check whether bz2 missing import is handled correctly (issue #53857) + code = """ + import sys + sys.modules['bz2'] = None + import pytest + import pandas as pd + from pandas.compat import get_bz2_file + msg = 'bz2 module not available.' + with pytest.raises(RuntimeError, match=msg): + get_bz2_file() + """ + code = textwrap.dedent(code) + call = [sys.executable, "-c", code] + subprocess.check_output(call) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_downstream.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..51ce73ef54300c5d2f9be4e3988de77843a61a8e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_downstream.py @@ -0,0 +1,362 @@ +""" +Testing that we work in the downstream packages +""" +import array +import subprocess +import sys + +import numpy as np +import pytest + +from pandas.errors import IntCastingNaNError +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) + + +@pytest.fixture +def df(): + return DataFrame({"A": [1, 2, 3]}) + + +def test_dask(df): + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + pytest.importorskip("toolz") + dd = pytest.importorskip("dask.dataframe") + + ddf = dd.from_pandas(df, npartitions=3) + assert ddf.A is not None + assert ddf.compute() is not None + finally: + pd.set_option("compute.use_numexpr", olduse) + + +def test_dask_ufunc(): + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + da = pytest.importorskip("dask.array") + dd = pytest.importorskip("dask.dataframe") + + s = Series([1.5, 2.3, 3.7, 4.0]) + ds = dd.from_pandas(s, npartitions=2) + + result = da.fix(ds).compute() + expected = np.fix(s) + tm.assert_series_equal(result, expected) + finally: + pd.set_option("compute.use_numexpr", olduse) + + +def test_construct_dask_float_array_int_dtype_match_ndarray(): + # GH#40110 make sure we treat a float-dtype dask array with the same + # rules we would for an ndarray + dd = pytest.importorskip("dask.dataframe") + + arr = np.array([1, 2.5, 3]) + darr = dd.from_array(arr) + + res = Series(darr) + expected = Series(arr) + tm.assert_series_equal(res, expected) + + # GH#49599 in 2.0 we raise instead of silently ignoring the dtype + msg = "Trying to coerce float values to integers" + with pytest.raises(ValueError, match=msg): + Series(darr, dtype="i8") + + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + arr[2] = np.nan + with pytest.raises(IntCastingNaNError, match=msg): + Series(darr, dtype="i8") + # which is the same as we get with a numpy input + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr, dtype="i8") + + +def test_xarray(df): + pytest.importorskip("xarray") + + assert df.to_xarray() is not None + + +def test_xarray_cftimeindex_nearest(): + # https://github.com/pydata/xarray/issues/3751 + cftime = pytest.importorskip("cftime") + xarray = pytest.importorskip("xarray") + + times = xarray.cftime_range("0001", periods=2) + key = cftime.DatetimeGregorian(2000, 1, 1) + result = times.get_indexer([key], method="nearest") + expected = 1 + assert result == expected + + +@pytest.mark.single_cpu +def test_oo_optimizable(): + # GH 21071 + subprocess.check_call([sys.executable, "-OO", "-c", "import pandas"]) + + +@pytest.mark.single_cpu +def test_oo_optimized_datetime_index_unpickle(): + # GH 42866 + subprocess.check_call( + [ + sys.executable, + "-OO", + "-c", + ( + "import pandas as pd, pickle; " + "pickle.loads(pickle.dumps(pd.date_range('2021-01-01', periods=1)))" + ), + ] + ) + + +def test_statsmodels(): + smf = pytest.importorskip("statsmodels.formula.api") + + df = DataFrame( + {"Lottery": range(5), "Literacy": range(5), "Pop1831": range(100, 105)} + ) + smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=df).fit() + + +def test_scikit_learn(): + pytest.importorskip("sklearn") + from sklearn import ( + datasets, + svm, + ) + + digits = datasets.load_digits() + clf = svm.SVC(gamma=0.001, C=100.0) + clf.fit(digits.data[:-1], digits.target[:-1]) + clf.predict(digits.data[-1:]) + + +def test_seaborn(): + seaborn = pytest.importorskip("seaborn") + tips = DataFrame( + {"day": pd.date_range("2023", freq="D", periods=5), "total_bill": range(5)} + ) + seaborn.stripplot(x="day", y="total_bill", data=tips) + + +def test_pandas_datareader(): + pytest.importorskip("pandas_datareader") + + +@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning") +def test_pyarrow(df): + pyarrow = pytest.importorskip("pyarrow") + table = pyarrow.Table.from_pandas(df) + result = table.to_pandas() + tm.assert_frame_equal(result, df) + + +def test_yaml_dump(df): + # GH#42748 + yaml = pytest.importorskip("yaml") + + dumped = yaml.dump(df) + + loaded = yaml.load(dumped, Loader=yaml.Loader) + tm.assert_frame_equal(df, loaded) + + loaded2 = yaml.load(dumped, Loader=yaml.UnsafeLoader) + tm.assert_frame_equal(df, loaded2) + + +@pytest.mark.single_cpu +def test_missing_required_dependency(): + # GH 23868 + # To ensure proper isolation, we pass these flags + # -S : disable site-packages + # -s : disable user site-packages + # -E : disable PYTHON* env vars, especially PYTHONPATH + # https://github.com/MacPython/pandas-wheels/pull/50 + + pyexe = sys.executable.replace("\\", "/") + + # We skip this test if pandas is installed as a site package. We first + # import the package normally and check the path to the module before + # executing the test which imports pandas with site packages disabled. + call = [pyexe, "-c", "import pandas;print(pandas.__file__)"] + output = subprocess.check_output(call).decode() + if "site-packages" in output: + pytest.skip("pandas installed as site package") + + # This test will fail if pandas is installed as a site package. The flags + # prevent pandas being imported and the test will report Failed: DID NOT + # RAISE + call = [pyexe, "-sSE", "-c", "import pandas"] + + msg = ( + rf"Command '\['{pyexe}', '-sSE', '-c', 'import pandas'\]' " + "returned non-zero exit status 1." + ) + + with pytest.raises(subprocess.CalledProcessError, match=msg) as exc: + subprocess.check_output(call, stderr=subprocess.STDOUT) + + output = exc.value.stdout.decode() + for name in ["numpy", "pytz", "dateutil"]: + assert name in output + + +def test_frame_setitem_dask_array_into_new_col(): + # GH#47128 + + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + da = pytest.importorskip("dask.array") + + dda = da.array([1, 2]) + df = DataFrame({"a": ["a", "b"]}) + df["b"] = dda + df["c"] = dda + df.loc[[False, True], "b"] = 100 + result = df.loc[[1], :] + expected = DataFrame({"a": ["b"], "b": [100], "c": [2]}, index=[1]) + tm.assert_frame_equal(result, expected) + finally: + pd.set_option("compute.use_numexpr", olduse) + + +def test_pandas_priority(): + # GH#48347 + + class MyClass: + __pandas_priority__ = 5000 + + def __radd__(self, other): + return self + + left = MyClass() + right = Series(range(3)) + + assert right.__add__(left) is NotImplemented + assert right + left is left + + +@pytest.fixture( + params=[ + "memoryview", + "array", + pytest.param("dask", marks=td.skip_if_no("dask.array")), + pytest.param("xarray", marks=td.skip_if_no("xarray")), + ] +) +def array_likes(request): + """ + Fixture giving a numpy array and a parametrized 'data' object, which can + be a memoryview, array, dask or xarray object created from the numpy array. + """ + # GH#24539 recognize e.g xarray, dask, ... + arr = np.array([1, 2, 3], dtype=np.int64) + + name = request.param + if name == "memoryview": + data = memoryview(arr) + elif name == "array": + data = array.array("i", arr) + elif name == "dask": + import dask.array + + data = dask.array.array(arr) + elif name == "xarray": + import xarray as xr + + data = xr.DataArray(arr) + + return arr, data + + +@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) +def test_from_obscure_array(dtype, array_likes): + # GH#24539 recognize e.g xarray, dask, ... + # Note: we dont do this for PeriodArray bc _from_sequence won't accept + # an array of integers + # TODO: could check with arraylike of Period objects + arr, data = array_likes + + cls = {"M8[ns]": DatetimeArray, "m8[ns]": TimedeltaArray}[dtype] + + depr_msg = f"{cls.__name__}.__init__ is deprecated" + with tm.assert_produces_warning(FutureWarning, match=depr_msg): + expected = cls(arr) + result = cls._from_sequence(data, dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + if not isinstance(data, memoryview): + # FIXME(GH#44431) these raise on memoryview and attempted fix + # fails on py3.10 + func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype] + result = func(arr).array + expected = func(data).array + tm.assert_equal(result, expected) + + # Let's check the Indexes while we're here + idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype] + result = idx_cls(arr) + expected = idx_cls(data) + tm.assert_index_equal(result, expected) + + +def test_dataframe_consortium() -> None: + """ + Test some basic methods of the dataframe consortium standard. + + Full testing is done at https://github.com/data-apis/dataframe-api-compat, + this is just to check that the entry point works as expected. + """ + pytest.importorskip("dataframe_api_compat") + df_pd = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = df_pd.__dataframe_consortium_standard__() + result_1 = df.get_column_names() + expected_1 = ["a", "b"] + assert result_1 == expected_1 + + ser = Series([1, 2, 3], name="a") + col = ser.__column_consortium_standard__() + assert col.name == "a" + + +def test_xarray_coerce_unit(): + # GH44053 + xr = pytest.importorskip("xarray") + + arr = xr.DataArray([1, 2, 3]) + result = pd.to_datetime(arr, unit="ns") + expected = DatetimeIndex( + [ + "1970-01-01 00:00:00.000000001", + "1970-01-01 00:00:00.000000002", + "1970-01-01 00:00:00.000000003", + ], + dtype="datetime64[ns]", + freq=None, + ) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_errors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..aeddc08e4b888c0937a3095a46003613e0115876 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_errors.py @@ -0,0 +1,112 @@ +import pytest + +from pandas.errors import ( + AbstractMethodError, + UndefinedVariableError, +) + +import pandas as pd + + +@pytest.mark.parametrize( + "exc", + [ + "AttributeConflictWarning", + "CSSWarning", + "CategoricalConversionWarning", + "ClosedFileError", + "DataError", + "DatabaseError", + "DtypeWarning", + "EmptyDataError", + "IncompatibilityWarning", + "IndexingError", + "InvalidColumnName", + "InvalidComparison", + "InvalidVersion", + "LossySetitemError", + "MergeError", + "NoBufferPresent", + "NumExprClobberingError", + "NumbaUtilError", + "OptionError", + "OutOfBoundsDatetime", + "ParserError", + "ParserWarning", + "PerformanceWarning", + "PossibleDataLossError", + "PossiblePrecisionLoss", + "PyperclipException", + "SettingWithCopyError", + "SettingWithCopyWarning", + "SpecificationError", + "UnsortedIndexError", + "UnsupportedFunctionCall", + "ValueLabelTypeMismatch", + ], +) +def test_exception_importable(exc): + from pandas import errors + + err = getattr(errors, exc) + assert err is not None + + # check that we can raise on them + + msg = "^$" + + with pytest.raises(err, match=msg): + raise err() + + +def test_catch_oob(): + from pandas import errors + + msg = "Cannot cast 1500-01-01 00:00:00 to unit='ns' without overflow" + with pytest.raises(errors.OutOfBoundsDatetime, match=msg): + pd.Timestamp("15000101").as_unit("ns") + + +@pytest.mark.parametrize( + "is_local", + [ + True, + False, + ], +) +def test_catch_undefined_variable_error(is_local): + variable_name = "x" + if is_local: + msg = f"local variable '{variable_name}' is not defined" + else: + msg = f"name '{variable_name}' is not defined" + + with pytest.raises(UndefinedVariableError, match=msg): + raise UndefinedVariableError(variable_name, is_local) + + +class Foo: + @classmethod + def classmethod(cls): + raise AbstractMethodError(cls, methodtype="classmethod") + + @property + def property(self): + raise AbstractMethodError(self, methodtype="property") + + def method(self): + raise AbstractMethodError(self) + + +def test_AbstractMethodError_classmethod(): + xpr = "This classmethod must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo.classmethod() + + xpr = "This property must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo().property + + xpr = "This method must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo().method() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_expressions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfec99f0786ebf11a44dedfad8aa8e1015356fab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_expressions.py @@ -0,0 +1,466 @@ +import operator +import re + +import numpy as np +import pytest + +from pandas import option_context +import pandas._testing as tm +from pandas.core.api import ( + DataFrame, + Index, + Series, +) +from pandas.core.computation import expressions as expr + + +@pytest.fixture +def _frame(): + return DataFrame( + np.random.default_rng(2).standard_normal((10001, 4)), + columns=list("ABCD"), + dtype="float64", + ) + + +@pytest.fixture +def _frame2(): + return DataFrame( + np.random.default_rng(2).standard_normal((100, 4)), + columns=list("ABCD"), + dtype="float64", + ) + + +@pytest.fixture +def _mixed(_frame): + return DataFrame( + { + "A": _frame["A"].copy(), + "B": _frame["B"].astype("float32"), + "C": _frame["C"].astype("int64"), + "D": _frame["D"].astype("int32"), + } + ) + + +@pytest.fixture +def _mixed2(_frame2): + return DataFrame( + { + "A": _frame2["A"].copy(), + "B": _frame2["B"].astype("float32"), + "C": _frame2["C"].astype("int64"), + "D": _frame2["D"].astype("int32"), + } + ) + + +@pytest.fixture +def _integer(): + return DataFrame( + np.random.default_rng(2).integers(1, 100, size=(10001, 4)), + columns=list("ABCD"), + dtype="int64", + ) + + +@pytest.fixture +def _integer_integers(_integer): + # integers to get a case with zeros + return _integer * np.random.default_rng(2).integers(0, 2, size=np.shape(_integer)) + + +@pytest.fixture +def _integer2(): + return DataFrame( + np.random.default_rng(2).integers(1, 100, size=(101, 4)), + columns=list("ABCD"), + dtype="int64", + ) + + +@pytest.fixture +def _array(_frame): + return _frame["A"].values.copy() + + +@pytest.fixture +def _array2(_frame2): + return _frame2["A"].values.copy() + + +@pytest.fixture +def _array_mixed(_mixed): + return _mixed["D"].values.copy() + + +@pytest.fixture +def _array_mixed2(_mixed2): + return _mixed2["D"].values.copy() + + +@pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr") +class TestExpressions: + @staticmethod + def call_op(df, other, flex: bool, opname: str): + if flex: + op = lambda x, y: getattr(x, opname)(y) + op.__name__ = opname + else: + op = getattr(operator, opname) + + with option_context("compute.use_numexpr", False): + expected = op(df, other) + + expr.get_test_result() + + result = op(df, other) + return result, expected + + @pytest.mark.parametrize( + "fixture", + [ + "_integer", + "_integer2", + "_integer_integers", + "_frame", + "_frame2", + "_mixed", + "_mixed2", + ], + ) + @pytest.mark.parametrize("flex", [True, False]) + @pytest.mark.parametrize( + "arith", ["add", "sub", "mul", "mod", "truediv", "floordiv"] + ) + def test_run_arithmetic(self, request, fixture, flex, arith, monkeypatch): + df = request.getfixturevalue(fixture) + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + result, expected = self.call_op(df, df, flex, arith) + + if arith == "truediv": + assert all(x.kind == "f" for x in expected.dtypes.values) + tm.assert_equal(expected, result) + + for i in range(len(df.columns)): + result, expected = self.call_op( + df.iloc[:, i], df.iloc[:, i], flex, arith + ) + if arith == "truediv": + assert expected.dtype.kind == "f" + tm.assert_equal(expected, result) + + @pytest.mark.parametrize( + "fixture", + [ + "_integer", + "_integer2", + "_integer_integers", + "_frame", + "_frame2", + "_mixed", + "_mixed2", + ], + ) + @pytest.mark.parametrize("flex", [True, False]) + def test_run_binary(self, request, fixture, flex, comparison_op, monkeypatch): + """ + tests solely that the result is the same whether or not numexpr is + enabled. Need to test whether the function does the correct thing + elsewhere. + """ + df = request.getfixturevalue(fixture) + arith = comparison_op.__name__ + with option_context("compute.use_numexpr", False): + other = df.copy() + 1 + + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + expr.set_test_mode(True) + + result, expected = self.call_op(df, other, flex, arith) + + used_numexpr = expr.get_test_result() + assert used_numexpr, "Did not use numexpr as expected." + tm.assert_equal(expected, result) + + for i in range(len(df.columns)): + binary_comp = other.iloc[:, i] + 1 + self.call_op(df.iloc[:, i], binary_comp, flex, "add") + + def test_invalid(self): + array = np.random.default_rng(2).standard_normal(1_000_001) + array2 = np.random.default_rng(2).standard_normal(100) + + # no op + result = expr._can_use_numexpr(operator.add, None, array, array, "evaluate") + assert not result + + # min elements + result = expr._can_use_numexpr(operator.add, "+", array2, array2, "evaluate") + assert not result + + # ok, we only check on first part of expression + result = expr._can_use_numexpr(operator.add, "+", array, array2, "evaluate") + assert result + + @pytest.mark.filterwarnings("ignore:invalid value encountered in:RuntimeWarning") + @pytest.mark.parametrize( + "opname,op_str", + [("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")], + ) + @pytest.mark.parametrize( + "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")] + ) + def test_binary_ops(self, request, opname, op_str, left_fix, right_fix): + left = request.getfixturevalue(left_fix) + right = request.getfixturevalue(right_fix) + + def testit(left, right, opname, op_str): + if opname == "pow": + left = np.abs(left) + + op = getattr(operator, opname) + + # array has 0s + result = expr.evaluate(op, left, left, use_numexpr=True) + expected = expr.evaluate(op, left, left, use_numexpr=False) + tm.assert_numpy_array_equal(result, expected) + + result = expr._can_use_numexpr(op, op_str, right, right, "evaluate") + assert not result + + with option_context("compute.use_numexpr", False): + testit(left, right, opname, op_str) + + expr.set_numexpr_threads(1) + testit(left, right, opname, op_str) + expr.set_numexpr_threads() + testit(left, right, opname, op_str) + + @pytest.mark.parametrize( + "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")] + ) + def test_comparison_ops(self, request, comparison_op, left_fix, right_fix): + left = request.getfixturevalue(left_fix) + right = request.getfixturevalue(right_fix) + + def testit(): + f12 = left + 1 + f22 = right + 1 + + op = comparison_op + + result = expr.evaluate(op, left, f12, use_numexpr=True) + expected = expr.evaluate(op, left, f12, use_numexpr=False) + tm.assert_numpy_array_equal(result, expected) + + result = expr._can_use_numexpr(op, op, right, f22, "evaluate") + assert not result + + with option_context("compute.use_numexpr", False): + testit() + + expr.set_numexpr_threads(1) + testit() + expr.set_numexpr_threads() + testit() + + @pytest.mark.parametrize("cond", [True, False]) + @pytest.mark.parametrize("fixture", ["_frame", "_frame2", "_mixed", "_mixed2"]) + def test_where(self, request, cond, fixture): + df = request.getfixturevalue(fixture) + + def testit(): + c = np.empty(df.shape, dtype=np.bool_) + c.fill(cond) + result = expr.where(c, df.values, df.values + 1) + expected = np.where(c, df.values, df.values + 1) + tm.assert_numpy_array_equal(result, expected) + + with option_context("compute.use_numexpr", False): + testit() + + expr.set_numexpr_threads(1) + testit() + expr.set_numexpr_threads() + testit() + + @pytest.mark.parametrize( + "op_str,opname", [("/", "truediv"), ("//", "floordiv"), ("**", "pow")] + ) + def test_bool_ops_raise_on_arithmetic(self, op_str, opname): + df = DataFrame( + { + "a": np.random.default_rng(2).random(10) > 0.5, + "b": np.random.default_rng(2).random(10) > 0.5, + } + ) + + msg = f"operator '{opname}' not implemented for bool dtypes" + f = getattr(operator, opname) + err_msg = re.escape(msg) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df, df) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df.a, df.b) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df.a, True) + + with pytest.raises(NotImplementedError, match=err_msg): + f(False, df.a) + + with pytest.raises(NotImplementedError, match=err_msg): + f(False, df) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df, True) + + @pytest.mark.parametrize( + "op_str,opname", [("+", "add"), ("*", "mul"), ("-", "sub")] + ) + def test_bool_ops_warn_on_arithmetic(self, op_str, opname): + n = 10 + df = DataFrame( + { + "a": np.random.default_rng(2).random(n) > 0.5, + "b": np.random.default_rng(2).random(n) > 0.5, + } + ) + + subs = {"+": "|", "*": "&", "-": "^"} + sub_funcs = {"|": "or_", "&": "and_", "^": "xor"} + + f = getattr(operator, opname) + fe = getattr(operator, sub_funcs[subs[op_str]]) + + if op_str == "-": + # raises TypeError + return + + with tm.use_numexpr(True, min_elements=5): + with tm.assert_produces_warning(): + r = f(df, df) + e = fe(df, df) + tm.assert_frame_equal(r, e) + + with tm.assert_produces_warning(): + r = f(df.a, df.b) + e = fe(df.a, df.b) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(): + r = f(df.a, True) + e = fe(df.a, True) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(): + r = f(False, df.a) + e = fe(False, df.a) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(): + r = f(False, df) + e = fe(False, df) + tm.assert_frame_equal(r, e) + + with tm.assert_produces_warning(): + r = f(df, True) + e = fe(df, True) + tm.assert_frame_equal(r, e) + + @pytest.mark.parametrize( + "test_input,expected", + [ + ( + DataFrame( + [[0, 1, 2, "aa"], [0, 1, 2, "aa"]], columns=["a", "b", "c", "dtype"] + ), + DataFrame([[False, False], [False, False]], columns=["a", "dtype"]), + ), + ( + DataFrame( + [[0, 3, 2, "aa"], [0, 4, 2, "aa"], [0, 1, 1, "bb"]], + columns=["a", "b", "c", "dtype"], + ), + DataFrame( + [[False, False], [False, False], [False, False]], + columns=["a", "dtype"], + ), + ), + ], + ) + def test_bool_ops_column_name_dtype(self, test_input, expected): + # GH 22383 - .ne fails if columns containing column name 'dtype' + result = test_input.loc[:, ["a", "dtype"]].ne(test_input.loc[:, ["a", "dtype"]]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "arith", ("add", "sub", "mul", "mod", "truediv", "floordiv") + ) + @pytest.mark.parametrize("axis", (0, 1)) + def test_frame_series_axis(self, axis, arith, _frame, monkeypatch): + # GH#26736 Dataframe.floordiv(Series, axis=1) fails + + df = _frame + if axis == 1: + other = df.iloc[0, :] + else: + other = df.iloc[:, 0] + + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + + op_func = getattr(df, arith) + + with option_context("compute.use_numexpr", False): + expected = op_func(other, axis=axis) + + result = op_func(other, axis=axis) + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize( + "op", + [ + "__mod__", + "__rmod__", + "__floordiv__", + "__rfloordiv__", + ], + ) + @pytest.mark.parametrize("box", [DataFrame, Series, Index]) + @pytest.mark.parametrize("scalar", [-5, 5]) + def test_python_semantics_with_numexpr_installed( + self, op, box, scalar, monkeypatch + ): + # https://github.com/pandas-dev/pandas/issues/36047 + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + data = np.arange(-50, 50) + obj = box(data) + method = getattr(obj, op) + result = method(scalar) + + # compare result with numpy + with option_context("compute.use_numexpr", False): + expected = method(scalar) + + tm.assert_equal(result, expected) + + # compare result element-wise with Python + for i, elem in enumerate(data): + if box == DataFrame: + scalar_result = result.iloc[i, 0] + else: + scalar_result = result[i] + try: + expected = getattr(int(elem), op)(scalar) + except ZeroDivisionError: + pass + else: + assert scalar_result == expected diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_flags.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..9294b3fc3319b78b59d5637acdf3fd75737cd836 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_flags.py @@ -0,0 +1,48 @@ +import pytest + +import pandas as pd + + +class TestFlags: + def test_equality(self): + a = pd.DataFrame().set_flags(allows_duplicate_labels=True).flags + b = pd.DataFrame().set_flags(allows_duplicate_labels=False).flags + + assert a == a + assert b == b + assert a != b + assert a != 2 + + def test_set(self): + df = pd.DataFrame().set_flags(allows_duplicate_labels=True) + a = df.flags + a.allows_duplicate_labels = False + assert a.allows_duplicate_labels is False + a["allows_duplicate_labels"] = True + assert a.allows_duplicate_labels is True + + def test_repr(self): + a = repr(pd.DataFrame({"A"}).set_flags(allows_duplicate_labels=True).flags) + assert a == "" + a = repr(pd.DataFrame({"A"}).set_flags(allows_duplicate_labels=False).flags) + assert a == "" + + def test_obj_ref(self): + df = pd.DataFrame() + flags = df.flags + del df + with pytest.raises(ValueError, match="object has been deleted"): + flags.allows_duplicate_labels = True + + def test_getitem(self): + df = pd.DataFrame() + flags = df.flags + assert flags["allows_duplicate_labels"] is True + flags["allows_duplicate_labels"] = False + assert flags["allows_duplicate_labels"] is False + + with pytest.raises(KeyError, match="a"): + flags["a"] + + with pytest.raises(ValueError, match="a"): + flags["a"] = 10 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_multilevel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_multilevel.py new file mode 100644 index 0000000000000000000000000000000000000000..6644ec82fab17ac9e1c1744b595c38fda17114f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_multilevel.py @@ -0,0 +1,355 @@ +import datetime + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + MultiIndex, + Series, +) +import pandas._testing as tm + + +class TestMultiLevel: + def test_reindex_level(self, multiindex_year_month_day_dataframe_random_data): + # axis=0 + ymd = multiindex_year_month_day_dataframe_random_data + + month_sums = ymd.groupby("month").sum() + result = month_sums.reindex(ymd.index, level=1) + expected = ymd.groupby(level="month").transform("sum") + + tm.assert_frame_equal(result, expected) + + # Series + result = month_sums["A"].reindex(ymd.index, level=1) + expected = ymd["A"].groupby(level="month").transform("sum") + tm.assert_series_equal(result, expected, check_names=False) + + # axis=1 + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + gb = ymd.T.groupby("month", axis=1) + + month_sums = gb.sum() + result = month_sums.reindex(columns=ymd.index, level=1) + expected = ymd.groupby(level="month").transform("sum").T + tm.assert_frame_equal(result, expected) + + def test_reindex(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + expected = frame.iloc[[0, 3]] + reindexed = frame.loc[[("foo", "one"), ("bar", "one")]] + tm.assert_frame_equal(reindexed, expected) + + def test_reindex_preserve_levels( + self, multiindex_year_month_day_dataframe_random_data, using_copy_on_write + ): + ymd = multiindex_year_month_day_dataframe_random_data + + new_index = ymd.index[::10] + chunk = ymd.reindex(new_index) + if using_copy_on_write: + assert chunk.index.is_(new_index) + else: + assert chunk.index is new_index + + chunk = ymd.loc[new_index] + assert chunk.index.equals(new_index) + + ymdT = ymd.T + chunk = ymdT.reindex(columns=new_index) + if using_copy_on_write: + assert chunk.columns.is_(new_index) + else: + assert chunk.columns is new_index + + chunk = ymdT.loc[:, new_index] + assert chunk.columns.equals(new_index) + + def test_groupby_transform(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + s = frame["A"] + grouper = s.index.get_level_values(0) + + grouped = s.groupby(grouper, group_keys=False) + + applied = grouped.apply(lambda x: x * 2) + expected = grouped.transform(lambda x: x * 2) + result = applied.reindex(expected.index) + tm.assert_series_equal(result, expected, check_names=False) + + def test_groupby_corner(self): + midx = MultiIndex( + levels=[["foo"], ["bar"], ["baz"]], + codes=[[0], [0], [0]], + names=["one", "two", "three"], + ) + df = DataFrame( + [np.random.default_rng(2).random(4)], + columns=["a", "b", "c", "d"], + index=midx, + ) + # should work + df.groupby(level="three") + + def test_groupby_level_no_obs(self): + # #1697 + midx = MultiIndex.from_tuples( + [ + ("f1", "s1"), + ("f1", "s2"), + ("f2", "s1"), + ("f2", "s2"), + ("f3", "s1"), + ("f3", "s2"), + ] + ) + df = DataFrame([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], columns=midx) + df1 = df.loc(axis=1)[df.columns.map(lambda u: u[0] in ["f2", "f3"])] + + msg = "DataFrame.groupby with axis=1 is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + grouped = df1.groupby(axis=1, level=0) + result = grouped.sum() + assert (result.columns == ["f2", "f3"]).all() + + def test_setitem_with_expansion_multiindex_columns( + self, multiindex_year_month_day_dataframe_random_data + ): + ymd = multiindex_year_month_day_dataframe_random_data + + df = ymd[:5].T + df[2000, 1, 10] = df[2000, 1, 7] + assert isinstance(df.columns, MultiIndex) + assert (df[2000, 1, 10] == df[2000, 1, 7]).all() + + def test_alignment(self): + x = Series( + data=[1, 2, 3], index=MultiIndex.from_tuples([("A", 1), ("A", 2), ("B", 3)]) + ) + + y = Series( + data=[4, 5, 6], index=MultiIndex.from_tuples([("Z", 1), ("Z", 2), ("B", 3)]) + ) + + res = x - y + exp_index = x.index.union(y.index) + exp = x.reindex(exp_index) - y.reindex(exp_index) + tm.assert_series_equal(res, exp) + + # hit non-monotonic code path + res = x[::-1] - y[::-1] + exp_index = x.index.union(y.index) + exp = x.reindex(exp_index) - y.reindex(exp_index) + tm.assert_series_equal(res, exp) + + def test_groupby_multilevel(self, multiindex_year_month_day_dataframe_random_data): + ymd = multiindex_year_month_day_dataframe_random_data + + result = ymd.groupby(level=[0, 1]).mean() + + k1 = ymd.index.get_level_values(0) + k2 = ymd.index.get_level_values(1) + + expected = ymd.groupby([k1, k2]).mean() + + # TODO groupby with level_values drops names + tm.assert_frame_equal(result, expected, check_names=False) + assert result.index.names == ymd.index.names[:2] + + result2 = ymd.groupby(level=ymd.index.names[:2]).mean() + tm.assert_frame_equal(result, result2) + + def test_multilevel_consolidate(self): + index = MultiIndex.from_tuples( + [("foo", "one"), ("foo", "two"), ("bar", "one"), ("bar", "two")] + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), index=index, columns=index + ) + df["Totals", ""] = df.sum(1) + df = df._consolidate() + + def test_level_with_tuples(self): + index = MultiIndex( + levels=[[("foo", "bar", 0), ("foo", "baz", 0), ("foo", "qux", 0)], [0, 1]], + codes=[[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]], + ) + + series = Series(np.random.default_rng(2).standard_normal(6), index=index) + frame = DataFrame(np.random.default_rng(2).standard_normal((6, 4)), index=index) + + result = series[("foo", "bar", 0)] + result2 = series.loc[("foo", "bar", 0)] + expected = series[:2] + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + with pytest.raises(KeyError, match=r"^\(\('foo', 'bar', 0\), 2\)$"): + series[("foo", "bar", 0), 2] + + result = frame.loc[("foo", "bar", 0)] + result2 = frame.xs(("foo", "bar", 0)) + expected = frame[:2] + expected.index = expected.index.droplevel(0) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + index = MultiIndex( + levels=[[("foo", "bar"), ("foo", "baz"), ("foo", "qux")], [0, 1]], + codes=[[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]], + ) + + series = Series(np.random.default_rng(2).standard_normal(6), index=index) + frame = DataFrame(np.random.default_rng(2).standard_normal((6, 4)), index=index) + + result = series[("foo", "bar")] + result2 = series.loc[("foo", "bar")] + expected = series[:2] + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + result = frame.loc[("foo", "bar")] + result2 = frame.xs(("foo", "bar")) + expected = frame[:2] + expected.index = expected.index.droplevel(0) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + def test_reindex_level_partial_selection(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + result = frame.reindex(["foo", "qux"], level=0) + expected = frame.iloc[[0, 1, 2, 7, 8, 9]] + tm.assert_frame_equal(result, expected) + + result = frame.T.reindex(["foo", "qux"], axis=1, level=0) + tm.assert_frame_equal(result, expected.T) + + result = frame.loc[["foo", "qux"]] + tm.assert_frame_equal(result, expected) + + result = frame["A"].loc[["foo", "qux"]] + tm.assert_series_equal(result, expected["A"]) + + result = frame.T.loc[:, ["foo", "qux"]] + tm.assert_frame_equal(result, expected.T) + + @pytest.mark.parametrize("d", [4, "d"]) + def test_empty_frame_groupby_dtypes_consistency(self, d): + # GH 20888 + group_keys = ["a", "b", "c"] + df = DataFrame({"a": [1], "b": [2], "c": [3], "d": [d]}) + + g = df[df.a == 2].groupby(group_keys) + result = g.first().index + expected = MultiIndex( + levels=[[1], [2], [3]], codes=[[], [], []], names=["a", "b", "c"] + ) + + tm.assert_index_equal(result, expected) + + def test_duplicate_groupby_issues(self): + idx_tp = [ + ("600809", "20061231"), + ("600809", "20070331"), + ("600809", "20070630"), + ("600809", "20070331"), + ] + dt = ["demo", "demo", "demo", "demo"] + + idx = MultiIndex.from_tuples(idx_tp, names=["STK_ID", "RPT_Date"]) + s = Series(dt, index=idx) + + result = s.groupby(s.index).first() + assert len(result) == 3 + + def test_subsets_multiindex_dtype(self): + # GH 20757 + data = [["x", 1]] + columns = [("a", "b", np.nan), ("a", "c", 0.0)] + df = DataFrame(data, columns=MultiIndex.from_tuples(columns)) + expected = df.dtypes.a.b + result = df.a.b.dtypes + tm.assert_series_equal(result, expected) + + def test_datetime_object_multiindex(self): + data_dic = { + (0, datetime.date(2018, 3, 3)): {"A": 1, "B": 10}, + (0, datetime.date(2018, 3, 4)): {"A": 2, "B": 11}, + (1, datetime.date(2018, 3, 3)): {"A": 3, "B": 12}, + (1, datetime.date(2018, 3, 4)): {"A": 4, "B": 13}, + } + result = DataFrame.from_dict(data_dic, orient="index") + data = {"A": [1, 2, 3, 4], "B": [10, 11, 12, 13]} + index = [ + [0, 0, 1, 1], + [ + datetime.date(2018, 3, 3), + datetime.date(2018, 3, 4), + datetime.date(2018, 3, 3), + datetime.date(2018, 3, 4), + ], + ] + expected = DataFrame(data=data, index=index) + + tm.assert_frame_equal(result, expected) + + def test_multiindex_with_na(self): + df = DataFrame( + [ + ["A", np.nan, 1.23, 4.56], + ["A", "G", 1.23, 4.56], + ["A", "D", 9.87, 10.54], + ], + columns=["pivot_0", "pivot_1", "col_1", "col_2"], + ).set_index(["pivot_0", "pivot_1"]) + + df.at[("A", "F"), "col_2"] = 0.0 + + expected = DataFrame( + [ + ["A", np.nan, 1.23, 4.56], + ["A", "G", 1.23, 4.56], + ["A", "D", 9.87, 10.54], + ["A", "F", np.nan, 0.0], + ], + columns=["pivot_0", "pivot_1", "col_1", "col_2"], + ).set_index(["pivot_0", "pivot_1"]) + + tm.assert_frame_equal(df, expected) + + +class TestSorted: + """everything you wanted to test about sorting""" + + def test_sort_non_lexsorted(self): + # degenerate case where we sort but don't + # have a satisfying result :< + # GH 15797 + idx = MultiIndex( + [["A", "B", "C"], ["c", "b", "a"]], [[0, 1, 2, 0, 1, 2], [0, 2, 1, 1, 0, 2]] + ) + + df = DataFrame({"col": range(len(idx))}, index=idx, dtype="int64") + assert df.index.is_monotonic_increasing is False + + sorted = df.sort_index() + assert sorted.index.is_monotonic_increasing is True + + expected = DataFrame( + {"col": [1, 4, 5, 2]}, + index=MultiIndex.from_tuples( + [("B", "a"), ("B", "c"), ("C", "a"), ("C", "b")] + ), + dtype="int64", + ) + result = sorted.loc[pd.IndexSlice["B":"C", "a":"c"], :] + tm.assert_frame_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_nanops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_nanops.py new file mode 100644 index 0000000000000000000000000000000000000000..a50054f33f382ed913261e0cafd944c2fd86aaa3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_nanops.py @@ -0,0 +1,1274 @@ +from functools import partial + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +from pandas import ( + Series, + isna, +) +import pandas._testing as tm +from pandas.core import nanops + +use_bn = nanops._USE_BOTTLENECK + + +@pytest.fixture +def disable_bottleneck(monkeypatch): + with monkeypatch.context() as m: + m.setattr(nanops, "_USE_BOTTLENECK", False) + yield + + +@pytest.fixture +def arr_shape(): + return 11, 7 + + +@pytest.fixture +def arr_float(arr_shape): + return np.random.default_rng(2).standard_normal(arr_shape) + + +@pytest.fixture +def arr_complex(arr_float): + return arr_float + arr_float * 1j + + +@pytest.fixture +def arr_int(arr_shape): + return np.random.default_rng(2).integers(-10, 10, arr_shape) + + +@pytest.fixture +def arr_bool(arr_shape): + return np.random.default_rng(2).integers(0, 2, arr_shape) == 0 + + +@pytest.fixture +def arr_str(arr_float): + return np.abs(arr_float).astype("S") + + +@pytest.fixture +def arr_utf(arr_float): + return np.abs(arr_float).astype("U") + + +@pytest.fixture +def arr_date(arr_shape): + return np.random.default_rng(2).integers(0, 20000, arr_shape).astype("M8[ns]") + + +@pytest.fixture +def arr_tdelta(arr_shape): + return np.random.default_rng(2).integers(0, 20000, arr_shape).astype("m8[ns]") + + +@pytest.fixture +def arr_nan(arr_shape): + return np.tile(np.nan, arr_shape) + + +@pytest.fixture +def arr_float_nan(arr_float, arr_nan): + return np.vstack([arr_float, arr_nan]) + + +@pytest.fixture +def arr_nan_float1(arr_nan, arr_float): + return np.vstack([arr_nan, arr_float]) + + +@pytest.fixture +def arr_nan_nan(arr_nan): + return np.vstack([arr_nan, arr_nan]) + + +@pytest.fixture +def arr_inf(arr_float): + return arr_float * np.inf + + +@pytest.fixture +def arr_float_inf(arr_float, arr_inf): + return np.vstack([arr_float, arr_inf]) + + +@pytest.fixture +def arr_nan_inf(arr_nan, arr_inf): + return np.vstack([arr_nan, arr_inf]) + + +@pytest.fixture +def arr_float_nan_inf(arr_float, arr_nan, arr_inf): + return np.vstack([arr_float, arr_nan, arr_inf]) + + +@pytest.fixture +def arr_nan_nan_inf(arr_nan, arr_inf): + return np.vstack([arr_nan, arr_nan, arr_inf]) + + +@pytest.fixture +def arr_obj( + arr_float, arr_int, arr_bool, arr_complex, arr_str, arr_utf, arr_date, arr_tdelta +): + return np.vstack( + [ + arr_float.astype("O"), + arr_int.astype("O"), + arr_bool.astype("O"), + arr_complex.astype("O"), + arr_str.astype("O"), + arr_utf.astype("O"), + arr_date.astype("O"), + arr_tdelta.astype("O"), + ] + ) + + +@pytest.fixture +def arr_nan_nanj(arr_nan): + with np.errstate(invalid="ignore"): + return arr_nan + arr_nan * 1j + + +@pytest.fixture +def arr_complex_nan(arr_complex, arr_nan_nanj): + with np.errstate(invalid="ignore"): + return np.vstack([arr_complex, arr_nan_nanj]) + + +@pytest.fixture +def arr_nan_infj(arr_inf): + with np.errstate(invalid="ignore"): + return arr_inf * 1j + + +@pytest.fixture +def arr_complex_nan_infj(arr_complex, arr_nan_infj): + with np.errstate(invalid="ignore"): + return np.vstack([arr_complex, arr_nan_infj]) + + +@pytest.fixture +def arr_float_1d(arr_float): + return arr_float[:, 0] + + +@pytest.fixture +def arr_nan_1d(arr_nan): + return arr_nan[:, 0] + + +@pytest.fixture +def arr_float_nan_1d(arr_float_nan): + return arr_float_nan[:, 0] + + +@pytest.fixture +def arr_float1_nan_1d(arr_float1_nan): + return arr_float1_nan[:, 0] + + +@pytest.fixture +def arr_nan_float1_1d(arr_nan_float1): + return arr_nan_float1[:, 0] + + +class TestnanopsDataFrame: + def setup_method(self): + nanops._USE_BOTTLENECK = False + + arr_shape = (11, 7) + + self.arr_float = np.random.default_rng(2).standard_normal(arr_shape) + self.arr_float1 = np.random.default_rng(2).standard_normal(arr_shape) + self.arr_complex = self.arr_float + self.arr_float1 * 1j + self.arr_int = np.random.default_rng(2).integers(-10, 10, arr_shape) + self.arr_bool = np.random.default_rng(2).integers(0, 2, arr_shape) == 0 + self.arr_str = np.abs(self.arr_float).astype("S") + self.arr_utf = np.abs(self.arr_float).astype("U") + self.arr_date = ( + np.random.default_rng(2).integers(0, 20000, arr_shape).astype("M8[ns]") + ) + self.arr_tdelta = ( + np.random.default_rng(2).integers(0, 20000, arr_shape).astype("m8[ns]") + ) + + self.arr_nan = np.tile(np.nan, arr_shape) + self.arr_float_nan = np.vstack([self.arr_float, self.arr_nan]) + self.arr_float1_nan = np.vstack([self.arr_float1, self.arr_nan]) + self.arr_nan_float1 = np.vstack([self.arr_nan, self.arr_float1]) + self.arr_nan_nan = np.vstack([self.arr_nan, self.arr_nan]) + + self.arr_inf = self.arr_float * np.inf + self.arr_float_inf = np.vstack([self.arr_float, self.arr_inf]) + + self.arr_nan_inf = np.vstack([self.arr_nan, self.arr_inf]) + self.arr_float_nan_inf = np.vstack([self.arr_float, self.arr_nan, self.arr_inf]) + self.arr_nan_nan_inf = np.vstack([self.arr_nan, self.arr_nan, self.arr_inf]) + self.arr_obj = np.vstack( + [ + self.arr_float.astype("O"), + self.arr_int.astype("O"), + self.arr_bool.astype("O"), + self.arr_complex.astype("O"), + self.arr_str.astype("O"), + self.arr_utf.astype("O"), + self.arr_date.astype("O"), + self.arr_tdelta.astype("O"), + ] + ) + + with np.errstate(invalid="ignore"): + self.arr_nan_nanj = self.arr_nan + self.arr_nan * 1j + self.arr_complex_nan = np.vstack([self.arr_complex, self.arr_nan_nanj]) + + self.arr_nan_infj = self.arr_inf * 1j + self.arr_complex_nan_infj = np.vstack([self.arr_complex, self.arr_nan_infj]) + + self.arr_float_2d = self.arr_float + self.arr_float1_2d = self.arr_float1 + + self.arr_nan_2d = self.arr_nan + self.arr_float_nan_2d = self.arr_float_nan + self.arr_float1_nan_2d = self.arr_float1_nan + self.arr_nan_float1_2d = self.arr_nan_float1 + + self.arr_float_1d = self.arr_float[:, 0] + self.arr_float1_1d = self.arr_float1[:, 0] + + self.arr_nan_1d = self.arr_nan[:, 0] + self.arr_float_nan_1d = self.arr_float_nan[:, 0] + self.arr_float1_nan_1d = self.arr_float1_nan[:, 0] + self.arr_nan_float1_1d = self.arr_nan_float1[:, 0] + + def teardown_method(self): + nanops._USE_BOTTLENECK = use_bn + + def check_results(self, targ, res, axis, check_dtype=True): + res = getattr(res, "asm8", res) + + if ( + axis != 0 + and hasattr(targ, "shape") + and targ.ndim + and targ.shape != res.shape + ): + res = np.split(res, [targ.shape[0]], axis=0)[0] + + try: + tm.assert_almost_equal(targ, res, check_dtype=check_dtype) + except AssertionError: + # handle timedelta dtypes + if hasattr(targ, "dtype") and targ.dtype == "m8[ns]": + raise + + # There are sometimes rounding errors with + # complex and object dtypes. + # If it isn't one of those, re-raise the error. + if not hasattr(res, "dtype") or res.dtype.kind not in ["c", "O"]: + raise + # convert object dtypes to something that can be split into + # real and imaginary parts + if res.dtype.kind == "O": + if targ.dtype.kind != "O": + res = res.astype(targ.dtype) + else: + cast_dtype = "c16" if hasattr(np, "complex128") else "f8" + res = res.astype(cast_dtype) + targ = targ.astype(cast_dtype) + # there should never be a case where numpy returns an object + # but nanops doesn't, so make that an exception + elif targ.dtype.kind == "O": + raise + tm.assert_almost_equal(np.real(targ), np.real(res), check_dtype=check_dtype) + tm.assert_almost_equal(np.imag(targ), np.imag(res), check_dtype=check_dtype) + + def check_fun_data( + self, + testfunc, + targfunc, + testarval, + targarval, + skipna, + check_dtype=True, + empty_targfunc=None, + **kwargs, + ): + for axis in list(range(targarval.ndim)) + [None]: + targartempval = targarval if skipna else testarval + if skipna and empty_targfunc and isna(targartempval).all(): + targ = empty_targfunc(targartempval, axis=axis, **kwargs) + else: + targ = targfunc(targartempval, axis=axis, **kwargs) + + if targartempval.dtype == object and ( + targfunc is np.any or targfunc is np.all + ): + # GH#12863 the numpy functions will retain e.g. floatiness + if isinstance(targ, np.ndarray): + targ = targ.astype(bool) + else: + targ = bool(targ) + + res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs) + + if ( + isinstance(targ, np.complex128) + and isinstance(res, float) + and np.isnan(targ) + and np.isnan(res) + ): + # GH#18463 + targ = res + + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna: + res = testfunc(testarval, axis=axis, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if axis is None: + res = testfunc(testarval, skipna=skipna, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna and axis is None: + res = testfunc(testarval, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + + if testarval.ndim <= 1: + return + + # Recurse on lower-dimension + testarval2 = np.take(testarval, 0, axis=-1) + targarval2 = np.take(targarval, 0, axis=-1) + self.check_fun_data( + testfunc, + targfunc, + testarval2, + targarval2, + skipna=skipna, + check_dtype=check_dtype, + empty_targfunc=empty_targfunc, + **kwargs, + ) + + def check_fun( + self, testfunc, targfunc, testar, skipna, empty_targfunc=None, **kwargs + ): + targar = testar + if testar.endswith("_nan") and hasattr(self, testar[:-4]): + targar = testar[:-4] + + testarval = getattr(self, testar) + targarval = getattr(self, targar) + self.check_fun_data( + testfunc, + targfunc, + testarval, + targarval, + skipna=skipna, + empty_targfunc=empty_targfunc, + **kwargs, + ) + + def check_funs( + self, + testfunc, + targfunc, + skipna, + allow_complex=True, + allow_all_nan=True, + allow_date=True, + allow_tdelta=True, + allow_obj=True, + **kwargs, + ): + self.check_fun(testfunc, targfunc, "arr_float", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_float_nan", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_int", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_bool", skipna, **kwargs) + objs = [ + self.arr_float.astype("O"), + self.arr_int.astype("O"), + self.arr_bool.astype("O"), + ] + + if allow_all_nan: + self.check_fun(testfunc, targfunc, "arr_nan", skipna, **kwargs) + + if allow_complex: + self.check_fun(testfunc, targfunc, "arr_complex", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_complex_nan", skipna, **kwargs) + if allow_all_nan: + self.check_fun(testfunc, targfunc, "arr_nan_nanj", skipna, **kwargs) + objs += [self.arr_complex.astype("O")] + + if allow_date: + targfunc(self.arr_date) + self.check_fun(testfunc, targfunc, "arr_date", skipna, **kwargs) + objs += [self.arr_date.astype("O")] + + if allow_tdelta: + try: + targfunc(self.arr_tdelta) + except TypeError: + pass + else: + self.check_fun(testfunc, targfunc, "arr_tdelta", skipna, **kwargs) + objs += [self.arr_tdelta.astype("O")] + + if allow_obj: + self.arr_obj = np.vstack(objs) + # some nanops handle object dtypes better than their numpy + # counterparts, so the numpy functions need to be given something + # else + if allow_obj == "convert": + targfunc = partial( + self._badobj_wrap, func=targfunc, allow_complex=allow_complex + ) + self.check_fun(testfunc, targfunc, "arr_obj", skipna, **kwargs) + + def _badobj_wrap(self, value, func, allow_complex=True, **kwargs): + if value.dtype.kind == "O": + if allow_complex: + value = value.astype("c16") + else: + value = value.astype("f8") + return func(value, **kwargs) + + @pytest.mark.parametrize( + "nan_op,np_op", [(nanops.nanany, np.any), (nanops.nanall, np.all)] + ) + def test_nan_funcs(self, nan_op, np_op, skipna): + self.check_funs(nan_op, np_op, skipna, allow_all_nan=False, allow_date=False) + + def test_nansum(self, skipna): + self.check_funs( + nanops.nansum, + np.sum, + skipna, + allow_date=False, + check_dtype=False, + empty_targfunc=np.nansum, + ) + + def test_nanmean(self, skipna): + self.check_funs( + nanops.nanmean, np.mean, skipna, allow_obj=False, allow_date=False + ) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanmedian(self, skipna): + self.check_funs( + nanops.nanmedian, + np.median, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanvar(self, ddof, skipna): + self.check_funs( + nanops.nanvar, + np.var, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanstd(self, ddof, skipna): + self.check_funs( + nanops.nanstd, + np.std, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nansem(self, ddof, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nansem, + sp_stats.sem, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize( + "nan_op,np_op", [(nanops.nanmin, np.min), (nanops.nanmax, np.max)] + ) + def test_nanops_with_warnings(self, nan_op, np_op, skipna): + self.check_funs(nan_op, np_op, skipna, allow_obj=False) + + def _argminmax_wrap(self, value, axis=None, func=None): + res = func(value, axis) + nans = np.min(value, axis) + nullnan = isna(nans) + if res.ndim: + res[nullnan] = -1 + elif ( + hasattr(nullnan, "all") + and nullnan.all() + or not hasattr(nullnan, "all") + and nullnan + ): + res = -1 + return res + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanargmax(self, skipna): + func = partial(self._argminmax_wrap, func=np.argmax) + self.check_funs(nanops.nanargmax, func, skipna, allow_obj=False) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanargmin(self, skipna): + func = partial(self._argminmax_wrap, func=np.argmin) + self.check_funs(nanops.nanargmin, func, skipna, allow_obj=False) + + def _skew_kurt_wrap(self, values, axis=None, func=None): + if not isinstance(values.dtype.type, np.floating): + values = values.astype("f8") + result = func(values, axis=axis, bias=False) + # fix for handling cases where all elements in an axis are the same + if isinstance(result, np.ndarray): + result[np.max(values, axis=axis) == np.min(values, axis=axis)] = 0 + return result + elif np.max(values) == np.min(values): + return 0.0 + return result + + def test_nanskew(self, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + func = partial(self._skew_kurt_wrap, func=sp_stats.skew) + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nanskew, + func, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + ) + + def test_nankurt(self, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + func1 = partial(sp_stats.kurtosis, fisher=True) + func = partial(self._skew_kurt_wrap, func=func1) + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nankurt, + func, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + ) + + def test_nanprod(self, skipna): + self.check_funs( + nanops.nanprod, + np.prod, + skipna, + allow_date=False, + allow_tdelta=False, + empty_targfunc=np.nanprod, + ) + + def check_nancorr_nancov_2d(self, checkfun, targ0, targ1, **kwargs): + res00 = checkfun(self.arr_float_2d, self.arr_float1_2d, **kwargs) + res01 = checkfun( + self.arr_float_2d, + self.arr_float1_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ0, res00) + tm.assert_almost_equal(targ0, res01) + + res10 = checkfun(self.arr_float_nan_2d, self.arr_float1_nan_2d, **kwargs) + res11 = checkfun( + self.arr_float_nan_2d, + self.arr_float1_nan_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ1, res10) + tm.assert_almost_equal(targ1, res11) + + targ2 = np.nan + res20 = checkfun(self.arr_nan_2d, self.arr_float1_2d, **kwargs) + res21 = checkfun(self.arr_float_2d, self.arr_nan_2d, **kwargs) + res22 = checkfun(self.arr_nan_2d, self.arr_nan_2d, **kwargs) + res23 = checkfun(self.arr_float_nan_2d, self.arr_nan_float1_2d, **kwargs) + res24 = checkfun( + self.arr_float_nan_2d, + self.arr_nan_float1_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + res25 = checkfun( + self.arr_float_2d, + self.arr_float1_2d, + min_periods=len(self.arr_float_2d) + 1, + **kwargs, + ) + tm.assert_almost_equal(targ2, res20) + tm.assert_almost_equal(targ2, res21) + tm.assert_almost_equal(targ2, res22) + tm.assert_almost_equal(targ2, res23) + tm.assert_almost_equal(targ2, res24) + tm.assert_almost_equal(targ2, res25) + + def check_nancorr_nancov_1d(self, checkfun, targ0, targ1, **kwargs): + res00 = checkfun(self.arr_float_1d, self.arr_float1_1d, **kwargs) + res01 = checkfun( + self.arr_float_1d, + self.arr_float1_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ0, res00) + tm.assert_almost_equal(targ0, res01) + + res10 = checkfun(self.arr_float_nan_1d, self.arr_float1_nan_1d, **kwargs) + res11 = checkfun( + self.arr_float_nan_1d, + self.arr_float1_nan_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ1, res10) + tm.assert_almost_equal(targ1, res11) + + targ2 = np.nan + res20 = checkfun(self.arr_nan_1d, self.arr_float1_1d, **kwargs) + res21 = checkfun(self.arr_float_1d, self.arr_nan_1d, **kwargs) + res22 = checkfun(self.arr_nan_1d, self.arr_nan_1d, **kwargs) + res23 = checkfun(self.arr_float_nan_1d, self.arr_nan_float1_1d, **kwargs) + res24 = checkfun( + self.arr_float_nan_1d, + self.arr_nan_float1_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + res25 = checkfun( + self.arr_float_1d, + self.arr_float1_1d, + min_periods=len(self.arr_float_1d) + 1, + **kwargs, + ) + tm.assert_almost_equal(targ2, res20) + tm.assert_almost_equal(targ2, res21) + tm.assert_almost_equal(targ2, res22) + tm.assert_almost_equal(targ2, res23) + tm.assert_almost_equal(targ2, res24) + tm.assert_almost_equal(targ2, res25) + + def test_nancorr(self): + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1) + targ0 = np.corrcoef(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.corrcoef(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="pearson") + + def test_nancorr_pearson(self): + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="pearson") + targ0 = np.corrcoef(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.corrcoef(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="pearson") + + def test_nancorr_kendall(self): + sp_stats = pytest.importorskip("scipy.stats") + + targ0 = sp_stats.kendalltau(self.arr_float_2d, self.arr_float1_2d)[0] + targ1 = sp_stats.kendalltau(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="kendall") + targ0 = sp_stats.kendalltau(self.arr_float_1d, self.arr_float1_1d)[0] + targ1 = sp_stats.kendalltau(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="kendall") + + def test_nancorr_spearman(self): + sp_stats = pytest.importorskip("scipy.stats") + + targ0 = sp_stats.spearmanr(self.arr_float_2d, self.arr_float1_2d)[0] + targ1 = sp_stats.spearmanr(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="spearman") + targ0 = sp_stats.spearmanr(self.arr_float_1d, self.arr_float1_1d)[0] + targ1 = sp_stats.spearmanr(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="spearman") + + def test_invalid_method(self): + pytest.importorskip("scipy") + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + msg = "Unknown method 'foo', expected one of 'kendall', 'spearman'" + with pytest.raises(ValueError, match=msg): + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="foo") + + def test_nancov(self): + targ0 = np.cov(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.cov(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancov, targ0, targ1) + targ0 = np.cov(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.cov(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancov, targ0, targ1) + + +@pytest.mark.parametrize( + "arr, correct", + [ + ("arr_complex", False), + ("arr_int", False), + ("arr_bool", False), + ("arr_str", False), + ("arr_utf", False), + ("arr_complex", False), + ("arr_complex_nan", False), + ("arr_nan_nanj", False), + ("arr_nan_infj", True), + ("arr_complex_nan_infj", True), + ], +) +def test_has_infs_non_float(request, arr, correct, disable_bottleneck): + val = request.getfixturevalue(arr) + while getattr(val, "ndim", True): + res0 = nanops._has_infs(val) + if correct: + assert res0 + else: + assert not res0 + + if not hasattr(val, "ndim"): + break + + # Reduce dimension for next step in the loop + val = np.take(val, 0, axis=-1) + + +@pytest.mark.parametrize( + "arr, correct", + [ + ("arr_float", False), + ("arr_nan", False), + ("arr_float_nan", False), + ("arr_nan_nan", False), + ("arr_float_inf", True), + ("arr_inf", True), + ("arr_nan_inf", True), + ("arr_float_nan_inf", True), + ("arr_nan_nan_inf", True), + ], +) +@pytest.mark.parametrize("astype", [None, "f4", "f2"]) +def test_has_infs_floats(request, arr, correct, astype, disable_bottleneck): + val = request.getfixturevalue(arr) + if astype is not None: + val = val.astype(astype) + while getattr(val, "ndim", True): + res0 = nanops._has_infs(val) + if correct: + assert res0 + else: + assert not res0 + + if not hasattr(val, "ndim"): + break + + # Reduce dimension for next step in the loop + val = np.take(val, 0, axis=-1) + + +@pytest.mark.parametrize( + "fixture", ["arr_float", "arr_complex", "arr_int", "arr_bool", "arr_str", "arr_utf"] +) +def test_bn_ok_dtype(fixture, request, disable_bottleneck): + obj = request.getfixturevalue(fixture) + assert nanops._bn_ok_dtype(obj.dtype, "test") + + +@pytest.mark.parametrize( + "fixture", + [ + "arr_date", + "arr_tdelta", + "arr_obj", + ], +) +def test_bn_not_ok_dtype(fixture, request, disable_bottleneck): + obj = request.getfixturevalue(fixture) + assert not nanops._bn_ok_dtype(obj.dtype, "test") + + +class TestEnsureNumeric: + def test_numeric_values(self): + # Test integer + assert nanops._ensure_numeric(1) == 1 + + # Test float + assert nanops._ensure_numeric(1.1) == 1.1 + + # Test complex + assert nanops._ensure_numeric(1 + 2j) == 1 + 2j + + def test_ndarray(self): + # Test numeric ndarray + values = np.array([1, 2, 3]) + assert np.allclose(nanops._ensure_numeric(values), values) + + # Test object ndarray + o_values = values.astype(object) + assert np.allclose(nanops._ensure_numeric(o_values), values) + + # Test convertible string ndarray + s_values = np.array(["1", "2", "3"], dtype=object) + msg = r"Could not convert \['1' '2' '3'\] to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric(s_values) + + # Test non-convertible string ndarray + s_values = np.array(["foo", "bar", "baz"], dtype=object) + msg = r"Could not convert .* to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric(s_values) + + def test_convertable_values(self): + with pytest.raises(TypeError, match="Could not convert string '1' to numeric"): + nanops._ensure_numeric("1") + with pytest.raises( + TypeError, match="Could not convert string '1.1' to numeric" + ): + nanops._ensure_numeric("1.1") + with pytest.raises( + TypeError, match=r"Could not convert string '1\+1j' to numeric" + ): + nanops._ensure_numeric("1+1j") + + def test_non_convertable_values(self): + msg = "Could not convert string 'foo' to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric("foo") + + # with the wrong type, python raises TypeError for us + msg = "argument must be a string or a number" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric({}) + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric([]) + + +class TestNanvarFixedValues: + # xref GH10242 + # Samples from a normal distribution. + @pytest.fixture + def variance(self): + return 3.0 + + @pytest.fixture + def samples(self, variance): + return self.prng.normal(scale=variance**0.5, size=100000) + + def test_nanvar_all_finite(self, samples, variance): + actual_variance = nanops.nanvar(samples) + tm.assert_almost_equal(actual_variance, variance, rtol=1e-2) + + def test_nanvar_nans(self, samples, variance): + samples_test = np.nan * np.ones(2 * samples.shape[0]) + samples_test[::2] = samples + + actual_variance = nanops.nanvar(samples_test, skipna=True) + tm.assert_almost_equal(actual_variance, variance, rtol=1e-2) + + actual_variance = nanops.nanvar(samples_test, skipna=False) + tm.assert_almost_equal(actual_variance, np.nan, rtol=1e-2) + + def test_nanstd_nans(self, samples, variance): + samples_test = np.nan * np.ones(2 * samples.shape[0]) + samples_test[::2] = samples + + actual_std = nanops.nanstd(samples_test, skipna=True) + tm.assert_almost_equal(actual_std, variance**0.5, rtol=1e-2) + + actual_std = nanops.nanvar(samples_test, skipna=False) + tm.assert_almost_equal(actual_std, np.nan, rtol=1e-2) + + def test_nanvar_axis(self, samples, variance): + # Generate some sample data. + samples_unif = self.prng.uniform(size=samples.shape[0]) + samples = np.vstack([samples, samples_unif]) + + actual_variance = nanops.nanvar(samples, axis=1) + tm.assert_almost_equal( + actual_variance, np.array([variance, 1.0 / 12]), rtol=1e-2 + ) + + def test_nanvar_ddof(self): + n = 5 + samples = self.prng.uniform(size=(10000, n + 1)) + samples[:, -1] = np.nan # Force use of our own algorithm. + + variance_0 = nanops.nanvar(samples, axis=1, skipna=True, ddof=0).mean() + variance_1 = nanops.nanvar(samples, axis=1, skipna=True, ddof=1).mean() + variance_2 = nanops.nanvar(samples, axis=1, skipna=True, ddof=2).mean() + + # The unbiased estimate. + var = 1.0 / 12 + tm.assert_almost_equal(variance_1, var, rtol=1e-2) + + # The underestimated variance. + tm.assert_almost_equal(variance_0, (n - 1.0) / n * var, rtol=1e-2) + + # The overestimated variance. + tm.assert_almost_equal(variance_2, (n - 1.0) / (n - 2.0) * var, rtol=1e-2) + + @pytest.mark.parametrize("axis", range(2)) + @pytest.mark.parametrize("ddof", range(3)) + def test_ground_truth(self, axis, ddof): + # Test against values that were precomputed with Numpy. + samples = np.empty((4, 4)) + samples[:3, :3] = np.array( + [ + [0.97303362, 0.21869576, 0.55560287], + [0.72980153, 0.03109364, 0.99155171], + [0.09317602, 0.60078248, 0.15871292], + ] + ) + samples[3] = samples[:, 3] = np.nan + + # Actual variances along axis=0, 1 for ddof=0, 1, 2 + variance = np.array( + [ + [ + [0.13762259, 0.05619224, 0.11568816], + [0.20643388, 0.08428837, 0.17353224], + [0.41286776, 0.16857673, 0.34706449], + ], + [ + [0.09519783, 0.16435395, 0.05082054], + [0.14279674, 0.24653093, 0.07623082], + [0.28559348, 0.49306186, 0.15246163], + ], + ] + ) + + # Test nanvar. + var = nanops.nanvar(samples, skipna=True, axis=axis, ddof=ddof) + tm.assert_almost_equal(var[:3], variance[axis, ddof]) + assert np.isnan(var[3]) + + # Test nanstd. + std = nanops.nanstd(samples, skipna=True, axis=axis, ddof=ddof) + tm.assert_almost_equal(std[:3], variance[axis, ddof] ** 0.5) + assert np.isnan(std[3]) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanstd_roundoff(self, ddof): + # Regression test for GH 10242 (test data taken from GH 10489). Ensure + # that variance is stable. + data = Series(766897346 * np.ones(10)) + result = data.std(ddof=ddof) + assert result == 0.0 + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestNanskewFixedValues: + # xref GH 11974 + # Test data + skewness value (computed with scipy.stats.skew) + @pytest.fixture + def samples(self): + return np.sin(np.linspace(0, 1, 200)) + + @pytest.fixture + def actual_skew(self): + return -0.1875895205961754 + + @pytest.mark.parametrize("val", [3075.2, 3075.3, 3075.5]) + def test_constant_series(self, val): + # xref GH 11974 + data = val * np.ones(300) + skew = nanops.nanskew(data) + assert skew == 0.0 + + def test_all_finite(self): + alpha, beta = 0.3, 0.1 + left_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nanskew(left_tailed) < 0 + + alpha, beta = 0.1, 0.3 + right_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nanskew(right_tailed) > 0 + + def test_ground_truth(self, samples, actual_skew): + skew = nanops.nanskew(samples) + tm.assert_almost_equal(skew, actual_skew) + + def test_axis(self, samples, actual_skew): + samples = np.vstack([samples, np.nan * np.ones(len(samples))]) + skew = nanops.nanskew(samples, axis=1) + tm.assert_almost_equal(skew, np.array([actual_skew, np.nan])) + + def test_nans(self, samples): + samples = np.hstack([samples, np.nan]) + skew = nanops.nanskew(samples, skipna=False) + assert np.isnan(skew) + + def test_nans_skipna(self, samples, actual_skew): + samples = np.hstack([samples, np.nan]) + skew = nanops.nanskew(samples, skipna=True) + tm.assert_almost_equal(skew, actual_skew) + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestNankurtFixedValues: + # xref GH 11974 + # Test data + kurtosis value (computed with scipy.stats.kurtosis) + @pytest.fixture + def samples(self): + return np.sin(np.linspace(0, 1, 200)) + + @pytest.fixture + def actual_kurt(self): + return -1.2058303433799713 + + @pytest.mark.parametrize("val", [3075.2, 3075.3, 3075.5]) + def test_constant_series(self, val): + # xref GH 11974 + data = val * np.ones(300) + kurt = nanops.nankurt(data) + assert kurt == 0.0 + + def test_all_finite(self): + alpha, beta = 0.3, 0.1 + left_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nankurt(left_tailed) < 2 + + alpha, beta = 0.1, 0.3 + right_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nankurt(right_tailed) < 0 + + def test_ground_truth(self, samples, actual_kurt): + kurt = nanops.nankurt(samples) + tm.assert_almost_equal(kurt, actual_kurt) + + def test_axis(self, samples, actual_kurt): + samples = np.vstack([samples, np.nan * np.ones(len(samples))]) + kurt = nanops.nankurt(samples, axis=1) + tm.assert_almost_equal(kurt, np.array([actual_kurt, np.nan])) + + def test_nans(self, samples): + samples = np.hstack([samples, np.nan]) + kurt = nanops.nankurt(samples, skipna=False) + assert np.isnan(kurt) + + def test_nans_skipna(self, samples, actual_kurt): + samples = np.hstack([samples, np.nan]) + kurt = nanops.nankurt(samples, skipna=True) + tm.assert_almost_equal(kurt, actual_kurt) + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestDatetime64NaNOps: + @pytest.fixture(params=["s", "ms", "us", "ns"]) + def unit(self, request): + return request.param + + # Enabling mean changes the behavior of DataFrame.mean + # See https://github.com/pandas-dev/pandas/issues/24752 + def test_nanmean(self, unit): + dti = pd.date_range("2016-01-01", periods=3).as_unit(unit) + expected = dti[1] + + for obj in [dti, dti._data]: + result = nanops.nanmean(obj) + assert result == expected + + dti2 = dti.insert(1, pd.NaT) + + for obj in [dti2, dti2._data]: + result = nanops.nanmean(obj) + assert result == expected + + @pytest.mark.parametrize("constructor", ["M8", "m8"]) + def test_nanmean_skipna_false(self, constructor, unit): + dtype = f"{constructor}[{unit}]" + arr = np.arange(12).astype(np.int64).view(dtype).reshape(4, 3) + + arr[-1, -1] = "NaT" + + result = nanops.nanmean(arr, skipna=False) + assert np.isnat(result) + assert result.dtype == dtype + + result = nanops.nanmean(arr, axis=0, skipna=False) + expected = np.array([4, 5, "NaT"], dtype=arr.dtype) + tm.assert_numpy_array_equal(result, expected) + + result = nanops.nanmean(arr, axis=1, skipna=False) + expected = np.array([arr[0, 1], arr[1, 1], arr[2, 1], arr[-1, -1]]) + tm.assert_numpy_array_equal(result, expected) + + +def test_use_bottleneck(): + if nanops._BOTTLENECK_INSTALLED: + with pd.option_context("use_bottleneck", True): + assert pd.get_option("use_bottleneck") + + with pd.option_context("use_bottleneck", False): + assert not pd.get_option("use_bottleneck") + + +@pytest.mark.parametrize( + "numpy_op, expected", + [ + (np.sum, 10), + (np.nansum, 10), + (np.mean, 2.5), + (np.nanmean, 2.5), + (np.median, 2.5), + (np.nanmedian, 2.5), + (np.min, 1), + (np.max, 4), + (np.nanmin, 1), + (np.nanmax, 4), + ], +) +def test_numpy_ops(numpy_op, expected): + # GH8383 + result = numpy_op(Series([1, 2, 3, 4])) + assert result == expected + + +@pytest.mark.parametrize( + "operation", + [ + nanops.nanany, + nanops.nanall, + nanops.nansum, + nanops.nanmean, + nanops.nanmedian, + nanops.nanstd, + nanops.nanvar, + nanops.nansem, + nanops.nanargmax, + nanops.nanargmin, + nanops.nanmax, + nanops.nanmin, + nanops.nanskew, + nanops.nankurt, + nanops.nanprod, + ], +) +def test_nanops_independent_of_mask_param(operation): + # GH22764 + ser = Series([1, 2, np.nan, 3, np.nan, 4]) + mask = ser.isna() + median_expected = operation(ser._values) + median_result = operation(ser._values, mask=mask) + assert median_expected == median_result + + +@pytest.mark.parametrize("min_count", [-1, 0]) +def test_check_below_min_count_negative_or_zero_min_count(min_count): + # GH35227 + result = nanops.check_below_min_count((21, 37), None, min_count) + expected_result = False + assert result == expected_result + + +@pytest.mark.parametrize( + "mask", [None, np.array([False, False, True]), np.array([True] + 9 * [False])] +) +@pytest.mark.parametrize("min_count, expected_result", [(1, False), (101, True)]) +def test_check_below_min_count_positive_min_count(mask, min_count, expected_result): + # GH35227 + shape = (10, 10) + result = nanops.check_below_min_count(shape, mask, min_count) + assert result == expected_result + + +@td.skip_if_windows +@td.skip_if_32bit +@pytest.mark.parametrize("min_count, expected_result", [(1, False), (2812191852, True)]) +def test_check_below_min_count_large_shape(min_count, expected_result): + # GH35227 large shape used to show that the issue is fixed + shape = (2244367, 1253) + result = nanops.check_below_min_count(shape, mask=None, min_count=min_count) + assert result == expected_result + + +@pytest.mark.parametrize("func", ["nanmean", "nansum"]) +def test_check_bottleneck_disallow(any_real_numpy_dtype, func): + # GH 42878 bottleneck sometimes produces unreliable results for mean and sum + assert not nanops._bn_ok_dtype(np.dtype(any_real_numpy_dtype).type, func) + + +@pytest.mark.parametrize("val", [2**55, -(2**55), 20150515061816532]) +def test_nanmean_overflow(disable_bottleneck, val): + # GH 10155 + # In the previous implementation mean can overflow for int dtypes, it + # is now consistent with numpy + + ser = Series(val, index=range(500), dtype=np.int64) + result = ser.mean() + np_result = ser.values.mean() + assert result == val + assert result == np_result + assert result.dtype == np.float64 + + +@pytest.mark.parametrize( + "dtype", + [ + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + getattr(np, "float128", None), + ], +) +@pytest.mark.parametrize("method", ["mean", "std", "var", "skew", "kurt", "min", "max"]) +def test_returned_dtype(disable_bottleneck, dtype, method): + if dtype is None: + pytest.skip("np.float128 not available") + + ser = Series(range(10), dtype=dtype) + result = getattr(ser, method)() + if is_integer_dtype(dtype) and method not in ["min", "max"]: + assert result.dtype == np.float64 + else: + assert result.dtype == dtype diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_optional_dependency.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_optional_dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..52b5f636b1254ceddf869704a6378f4ea5012b8c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_optional_dependency.py @@ -0,0 +1,100 @@ +import sys +import types + +import pytest + +from pandas.compat._optional import ( + VERSIONS, + import_optional_dependency, +) + +import pandas._testing as tm + + +def test_import_optional(): + match = "Missing .*notapackage.* pip .* conda .* notapackage" + with pytest.raises(ImportError, match=match) as exc_info: + import_optional_dependency("notapackage") + # The original exception should be there as context: + assert isinstance(exc_info.value.__context__, ImportError) + + result = import_optional_dependency("notapackage", errors="ignore") + assert result is None + + +def test_xlrd_version_fallback(): + pytest.importorskip("xlrd") + import_optional_dependency("xlrd") + + +def test_bad_version(monkeypatch): + name = "fakemodule" + module = types.ModuleType(name) + module.__version__ = "0.9.0" + sys.modules[name] = module + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + match = "Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'" + with pytest.raises(ImportError, match=match): + import_optional_dependency("fakemodule") + + # Test min_version parameter + result = import_optional_dependency("fakemodule", min_version="0.8") + assert result is module + + with tm.assert_produces_warning(UserWarning): + result = import_optional_dependency("fakemodule", errors="warn") + assert result is None + + module.__version__ = "1.0.0" # exact match is OK + result = import_optional_dependency("fakemodule") + assert result is module + + with pytest.raises(ImportError, match="Pandas requires version '1.1.0'"): + import_optional_dependency("fakemodule", min_version="1.1.0") + + with tm.assert_produces_warning(UserWarning): + result = import_optional_dependency( + "fakemodule", errors="warn", min_version="1.1.0" + ) + assert result is None + + result = import_optional_dependency( + "fakemodule", errors="ignore", min_version="1.1.0" + ) + assert result is None + + +def test_submodule(monkeypatch): + # Create a fake module with a submodule + name = "fakemodule" + module = types.ModuleType(name) + module.__version__ = "0.9.0" + sys.modules[name] = module + sub_name = "submodule" + submodule = types.ModuleType(sub_name) + setattr(module, sub_name, submodule) + sys.modules[f"{name}.{sub_name}"] = submodule + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + match = "Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'" + with pytest.raises(ImportError, match=match): + import_optional_dependency("fakemodule.submodule") + + with tm.assert_produces_warning(UserWarning): + result = import_optional_dependency("fakemodule.submodule", errors="warn") + assert result is None + + module.__version__ = "1.0.0" # exact match is OK + result = import_optional_dependency("fakemodule.submodule") + assert result is submodule + + +def test_no_version_raises(monkeypatch): + name = "fakemodule" + module = types.ModuleType(name) + sys.modules[name] = module + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + with pytest.raises(ImportError, match="Can't determine .* fakemodule"): + import_optional_dependency(name) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_register_accessor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_register_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..4e569dc40005d5883870b82f46859d5bc36578f9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_register_accessor.py @@ -0,0 +1,103 @@ +from collections.abc import Generator +import contextlib + +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core import accessor + + +def test_dirname_mixin() -> None: + # GH37173 + + class X(accessor.DirNamesMixin): + x = 1 + y: int + + def __init__(self) -> None: + self.z = 3 + + result = [attr_name for attr_name in dir(X()) if not attr_name.startswith("_")] + + assert result == ["x", "z"] + + +@contextlib.contextmanager +def ensure_removed(obj, attr) -> Generator[None, None, None]: + """Ensure that an attribute added to 'obj' during the test is + removed when we're done + """ + try: + yield + finally: + try: + delattr(obj, attr) + except AttributeError: + pass + obj._accessors.discard(attr) + + +class MyAccessor: + def __init__(self, obj) -> None: + self.obj = obj + self.item = "item" + + @property + def prop(self): + return self.item + + def method(self): + return self.item + + +@pytest.mark.parametrize( + "obj, registrar", + [ + (pd.Series, pd.api.extensions.register_series_accessor), + (pd.DataFrame, pd.api.extensions.register_dataframe_accessor), + (pd.Index, pd.api.extensions.register_index_accessor), + ], +) +def test_register(obj, registrar): + with ensure_removed(obj, "mine"): + before = set(dir(obj)) + registrar("mine")(MyAccessor) + o = obj([]) if obj is not pd.Series else obj([], dtype=object) + assert o.mine.prop == "item" + after = set(dir(obj)) + assert (before ^ after) == {"mine"} + assert "mine" in obj._accessors + + +def test_accessor_works(): + with ensure_removed(pd.Series, "mine"): + pd.api.extensions.register_series_accessor("mine")(MyAccessor) + + s = pd.Series([1, 2]) + assert s.mine.obj is s + + assert s.mine.prop == "item" + assert s.mine.method() == "item" + + +def test_overwrite_warns(): + match = r".*MyAccessor.*fake.*Series.*" + with tm.assert_produces_warning(UserWarning, match=match): + with ensure_removed(pd.Series, "fake"): + setattr(pd.Series, "fake", 123) + pd.api.extensions.register_series_accessor("fake")(MyAccessor) + s = pd.Series([1, 2]) + assert s.fake.prop == "item" + + +def test_raises_attribute_error(): + with ensure_removed(pd.Series, "bad"): + + @pd.api.extensions.register_series_accessor("bad") + class Bad: + def __init__(self, data) -> None: + raise AttributeError("whoops") + + with pytest.raises(AttributeError, match="whoops"): + pd.Series([], dtype=object).bad diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_sorting.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_sorting.py new file mode 100644 index 0000000000000000000000000000000000000000..285f240028152072ed52b3657113d30a7fd63fea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_sorting.py @@ -0,0 +1,487 @@ +from collections import defaultdict +from datetime import datetime +from itertools import product + +import numpy as np +import pytest + +from pandas import ( + NA, + DataFrame, + MultiIndex, + Series, + array, + concat, + merge, +) +import pandas._testing as tm +from pandas.core.algorithms import safe_sort +import pandas.core.common as com +from pandas.core.sorting import ( + _decons_group_index, + get_group_index, + is_int64_overflow_possible, + lexsort_indexer, + nargsort, +) + + +@pytest.fixture +def left_right(): + low, high, n = -1 << 10, 1 << 10, 1 << 20 + left = DataFrame( + np.random.default_rng(2).integers(low, high, (n, 7)), columns=list("ABCDEFG") + ) + left["left"] = left.sum(axis=1) + + # one-2-one match + i = np.random.default_rng(2).permutation(len(left)) + right = left.iloc[i].copy() + right.columns = right.columns[:-1].tolist() + ["right"] + right.index = np.arange(len(right)) + right["right"] *= -1 + return left, right + + +class TestSorting: + @pytest.mark.slow + def test_int64_overflow(self): + B = np.concatenate((np.arange(1000), np.arange(1000), np.arange(500))) + A = np.arange(2500) + df = DataFrame( + { + "A": A, + "B": B, + "C": A, + "D": B, + "E": A, + "F": B, + "G": A, + "H": B, + "values": np.random.default_rng(2).standard_normal(2500), + } + ) + + lg = df.groupby(["A", "B", "C", "D", "E", "F", "G", "H"]) + rg = df.groupby(["H", "G", "F", "E", "D", "C", "B", "A"]) + + left = lg.sum()["values"] + right = rg.sum()["values"] + + exp_index, _ = left.index.sortlevel() + tm.assert_index_equal(left.index, exp_index) + + exp_index, _ = right.index.sortlevel(0) + tm.assert_index_equal(right.index, exp_index) + + tups = list(map(tuple, df[["A", "B", "C", "D", "E", "F", "G", "H"]].values)) + tups = com.asarray_tuplesafe(tups) + + expected = df.groupby(tups).sum()["values"] + + for k, v in expected.items(): + assert left[k] == right[k[::-1]] + assert left[k] == v + assert len(left) == len(right) + + def test_int64_overflow_groupby_large_range(self): + # GH9096 + values = range(55109) + data = DataFrame.from_dict({"a": values, "b": values, "c": values, "d": values}) + grouped = data.groupby(["a", "b", "c", "d"]) + assert len(grouped) == len(values) + + @pytest.mark.parametrize("agg", ["mean", "median"]) + def test_int64_overflow_groupby_large_df_shuffled(self, agg): + rs = np.random.default_rng(2) + arr = rs.integers(-1 << 12, 1 << 12, (1 << 15, 5)) + i = rs.choice(len(arr), len(arr) * 4) + arr = np.vstack((arr, arr[i])) # add some duplicate rows + + i = rs.permutation(len(arr)) + arr = arr[i] # shuffle rows + + df = DataFrame(arr, columns=list("abcde")) + df["jim"], df["joe"] = np.zeros((2, len(df))) + gr = df.groupby(list("abcde")) + + # verify this is testing what it is supposed to test! + assert is_int64_overflow_possible(gr._grouper.shape) + + mi = MultiIndex.from_arrays( + [ar.ravel() for ar in np.array_split(np.unique(arr, axis=0), 5, axis=1)], + names=list("abcde"), + ) + + res = DataFrame( + np.zeros((len(mi), 2)), columns=["jim", "joe"], index=mi + ).sort_index() + + tm.assert_frame_equal(getattr(gr, agg)(), res) + + @pytest.mark.parametrize( + "order, na_position, exp", + [ + [ + True, + "last", + list(range(5, 105)) + list(range(5)) + list(range(105, 110)), + ], + [ + True, + "first", + list(range(5)) + list(range(105, 110)) + list(range(5, 105)), + ], + [ + False, + "last", + list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)), + ], + [ + False, + "first", + list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)), + ], + ], + ) + def test_lexsort_indexer(self, order, na_position, exp): + keys = [[np.nan] * 5 + list(range(100)) + [np.nan] * 5] + result = lexsort_indexer(keys, orders=order, na_position=na_position) + tm.assert_numpy_array_equal(result, np.array(exp, dtype=np.intp)) + + @pytest.mark.parametrize( + "ascending, na_position, exp", + [ + [ + True, + "last", + list(range(5, 105)) + list(range(5)) + list(range(105, 110)), + ], + [ + True, + "first", + list(range(5)) + list(range(105, 110)) + list(range(5, 105)), + ], + [ + False, + "last", + list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)), + ], + [ + False, + "first", + list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)), + ], + ], + ) + def test_nargsort(self, ascending, na_position, exp): + # list places NaNs last, np.array(..., dtype="O") may not place NaNs first + items = np.array([np.nan] * 5 + list(range(100)) + [np.nan] * 5, dtype="O") + + # mergesort is the most difficult to get right because we want it to be + # stable. + + # According to numpy/core/tests/test_multiarray, """The number of + # sorted items must be greater than ~50 to check the actual algorithm + # because quick and merge sort fall over to insertion sort for small + # arrays.""" + + result = nargsort( + items, kind="mergesort", ascending=ascending, na_position=na_position + ) + tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False) + + +class TestMerge: + def test_int64_overflow_outer_merge(self): + # #2690, combinatorial explosion + df1 = DataFrame( + np.random.default_rng(2).standard_normal((1000, 7)), + columns=list("ABCDEF") + ["G1"], + ) + df2 = DataFrame( + np.random.default_rng(3).standard_normal((1000, 7)), + columns=list("ABCDEF") + ["G2"], + ) + result = merge(df1, df2, how="outer") + assert len(result) == 2000 + + @pytest.mark.slow + def test_int64_overflow_check_sum_col(self, left_right): + left, right = left_right + + out = merge(left, right, how="outer") + assert len(out) == len(left) + tm.assert_series_equal(out["left"], -out["right"], check_names=False) + result = out.iloc[:, :-2].sum(axis=1) + tm.assert_series_equal(out["left"], result, check_names=False) + assert result.name is None + + @pytest.mark.slow + @pytest.mark.parametrize("how", ["left", "right", "outer", "inner"]) + def test_int64_overflow_how_merge(self, left_right, how): + left, right = left_right + + out = merge(left, right, how="outer") + out.sort_values(out.columns.tolist(), inplace=True) + out.index = np.arange(len(out)) + tm.assert_frame_equal(out, merge(left, right, how=how, sort=True)) + + @pytest.mark.slow + def test_int64_overflow_sort_false_order(self, left_right): + left, right = left_right + + # check that left merge w/ sort=False maintains left frame order + out = merge(left, right, how="left", sort=False) + tm.assert_frame_equal(left, out[left.columns.tolist()]) + + out = merge(right, left, how="left", sort=False) + tm.assert_frame_equal(right, out[right.columns.tolist()]) + + @pytest.mark.slow + @pytest.mark.parametrize("how", ["left", "right", "outer", "inner"]) + @pytest.mark.parametrize("sort", [True, False]) + def test_int64_overflow_one_to_many_none_match(self, how, sort): + # one-2-many/none match + low, high, n = -1 << 10, 1 << 10, 1 << 11 + left = DataFrame( + np.random.default_rng(2).integers(low, high, (n, 7)).astype("int64"), + columns=list("ABCDEFG"), + ) + + # confirm that this is checking what it is supposed to check + shape = left.apply(Series.nunique).values + assert is_int64_overflow_possible(shape) + + # add duplicates to left frame + left = concat([left, left], ignore_index=True) + + right = DataFrame( + np.random.default_rng(3).integers(low, high, (n // 2, 7)).astype("int64"), + columns=list("ABCDEFG"), + ) + + # add duplicates & overlap with left to the right frame + i = np.random.default_rng(4).choice(len(left), n) + right = concat([right, right, left.iloc[i]], ignore_index=True) + + left["left"] = np.random.default_rng(2).standard_normal(len(left)) + right["right"] = np.random.default_rng(2).standard_normal(len(right)) + + # shuffle left & right frames + i = np.random.default_rng(5).permutation(len(left)) + left = left.iloc[i].copy() + left.index = np.arange(len(left)) + + i = np.random.default_rng(6).permutation(len(right)) + right = right.iloc[i].copy() + right.index = np.arange(len(right)) + + # manually compute outer merge + ldict, rdict = defaultdict(list), defaultdict(list) + + for idx, row in left.set_index(list("ABCDEFG")).iterrows(): + ldict[idx].append(row["left"]) + + for idx, row in right.set_index(list("ABCDEFG")).iterrows(): + rdict[idx].append(row["right"]) + + vals = [] + for k, lval in ldict.items(): + rval = rdict.get(k, [np.nan]) + for lv, rv in product(lval, rval): + vals.append( + k + + ( + lv, + rv, + ) + ) + + for k, rval in rdict.items(): + if k not in ldict: + vals.extend( + k + + ( + np.nan, + rv, + ) + for rv in rval + ) + + def align(df): + df = df.sort_values(df.columns.tolist()) + df.index = np.arange(len(df)) + return df + + out = DataFrame(vals, columns=list("ABCDEFG") + ["left", "right"]) + out = align(out) + + jmask = { + "left": out["left"].notna(), + "right": out["right"].notna(), + "inner": out["left"].notna() & out["right"].notna(), + "outer": np.ones(len(out), dtype="bool"), + } + + mask = jmask[how] + frame = align(out[mask].copy()) + assert mask.all() ^ mask.any() or how == "outer" + + res = merge(left, right, how=how, sort=sort) + if sort: + kcols = list("ABCDEFG") + tm.assert_frame_equal( + res[kcols].copy(), res[kcols].sort_values(kcols, kind="mergesort") + ) + + # as in GH9092 dtypes break with outer/right join + # 2021-12-18: dtype does not break anymore + tm.assert_frame_equal(frame, align(res)) + + +@pytest.mark.parametrize( + "codes_list, shape", + [ + [ + [ + np.tile([0, 1, 2, 3, 0, 1, 2, 3], 100).astype(np.int64), + np.tile([0, 2, 4, 3, 0, 1, 2, 3], 100).astype(np.int64), + np.tile([5, 1, 0, 2, 3, 0, 5, 4], 100).astype(np.int64), + ], + (4, 5, 6), + ], + [ + [ + np.tile(np.arange(10000, dtype=np.int64), 5), + np.tile(np.arange(10000, dtype=np.int64), 5), + ], + (10000, 10000), + ], + ], +) +def test_decons(codes_list, shape): + group_index = get_group_index(codes_list, shape, sort=True, xnull=True) + codes_list2 = _decons_group_index(group_index, shape) + + for a, b in zip(codes_list, codes_list2): + tm.assert_numpy_array_equal(a, b) + + +class TestSafeSort: + @pytest.mark.parametrize( + "arg, exp", + [ + [[3, 1, 2, 0, 4], [0, 1, 2, 3, 4]], + [ + np.array(list("baaacb"), dtype=object), + np.array(list("aaabbc"), dtype=object), + ], + [[], []], + ], + ) + def test_basic_sort(self, arg, exp): + result = safe_sort(np.array(arg)) + expected = np.array(exp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("verify", [True, False]) + @pytest.mark.parametrize( + "codes, exp_codes", + [ + [[0, 1, 1, 2, 3, 0, -1, 4], [3, 1, 1, 2, 0, 3, -1, 4]], + [[], []], + ], + ) + def test_codes(self, verify, codes, exp_codes): + values = np.array([3, 1, 2, 0, 4]) + expected = np.array([0, 1, 2, 3, 4]) + + result, result_codes = safe_sort( + values, codes, use_na_sentinel=True, verify=verify + ) + expected_codes = np.array(exp_codes, dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_codes_out_of_bound(self): + values = np.array([3, 1, 2, 0, 4]) + expected = np.array([0, 1, 2, 3, 4]) + + # out of bound indices + codes = [0, 101, 102, 2, 3, 0, 99, 4] + result, result_codes = safe_sort(values, codes, use_na_sentinel=True) + expected_codes = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_mixed_integer(self): + values = np.array(["b", 1, 0, "a", 0, "b"], dtype=object) + result = safe_sort(values) + expected = np.array([0, 0, 1, "a", "b", "b"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_mixed_integer_with_codes(self): + values = np.array(["b", 1, 0, "a"], dtype=object) + codes = [0, 1, 2, 3, 0, -1, 1] + result, result_codes = safe_sort(values, codes) + expected = np.array([0, 1, "a", "b"], dtype=object) + expected_codes = np.array([3, 1, 0, 2, 3, -1, 1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_unsortable(self): + # GH 13714 + arr = np.array([1, 2, datetime.now(), 0, 3], dtype=object) + msg = "'[<>]' not supported between instances of .*" + with pytest.raises(TypeError, match=msg): + safe_sort(arr) + + @pytest.mark.parametrize( + "arg, codes, err, msg", + [ + [1, None, TypeError, "Only np.ndarray, ExtensionArray, and Index"], + [np.array([0, 1, 2]), 1, TypeError, "Only list-like objects or None"], + [np.array([0, 1, 2, 1]), [0, 1], ValueError, "values should be unique"], + ], + ) + def test_exceptions(self, arg, codes, err, msg): + with pytest.raises(err, match=msg): + safe_sort(values=arg, codes=codes) + + @pytest.mark.parametrize( + "arg, exp", [[[1, 3, 2], [1, 2, 3]], [[1, 3, np.nan, 2], [1, 2, 3, np.nan]]] + ) + def test_extension_array(self, arg, exp): + a = array(arg, dtype="Int64") + result = safe_sort(a) + expected = array(exp, dtype="Int64") + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize("verify", [True, False]) + def test_extension_array_codes(self, verify): + a = array([1, 3, 2], dtype="Int64") + result, codes = safe_sort(a, [0, 1, -1, 2], use_na_sentinel=True, verify=verify) + expected_values = array([1, 2, 3], dtype="Int64") + expected_codes = np.array([0, 2, -1, 1], dtype=np.intp) + tm.assert_extension_array_equal(result, expected_values) + tm.assert_numpy_array_equal(codes, expected_codes) + + +def test_mixed_str_null(nulls_fixture): + values = np.array(["b", nulls_fixture, "a", "b"], dtype=object) + result = safe_sort(values) + expected = np.array(["a", "b", "b", nulls_fixture], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +def test_safe_sort_multiindex(): + # GH#48412 + arr1 = Series([2, 1, NA, NA], dtype="Int64") + arr2 = [2, 1, 3, 3] + midx = MultiIndex.from_arrays([arr1, arr2]) + result = safe_sort(midx) + expected = MultiIndex.from_arrays( + [Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]] + ) + tm.assert_index_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_take.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_take.py new file mode 100644 index 0000000000000000000000000000000000000000..4f34ab34c35f0c2446597001f59526d4c8b0900d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/tests/test_take.py @@ -0,0 +1,307 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas._libs import iNaT + +import pandas._testing as tm +import pandas.core.algorithms as algos + + +@pytest.fixture( + params=[ + (np.int8, np.int16(127), np.int8), + (np.int8, np.int16(128), np.int16), + (np.int32, 1, np.int32), + (np.int32, 2.0, np.float64), + (np.int32, 3.0 + 4.0j, np.complex128), + (np.int32, True, np.object_), + (np.int32, "", np.object_), + (np.float64, 1, np.float64), + (np.float64, 2.0, np.float64), + (np.float64, 3.0 + 4.0j, np.complex128), + (np.float64, True, np.object_), + (np.float64, "", np.object_), + (np.complex128, 1, np.complex128), + (np.complex128, 2.0, np.complex128), + (np.complex128, 3.0 + 4.0j, np.complex128), + (np.complex128, True, np.object_), + (np.complex128, "", np.object_), + (np.bool_, 1, np.object_), + (np.bool_, 2.0, np.object_), + (np.bool_, 3.0 + 4.0j, np.object_), + (np.bool_, True, np.bool_), + (np.bool_, "", np.object_), + ] +) +def dtype_fill_out_dtype(request): + return request.param + + +class TestTake: + def test_1d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + data = np.random.default_rng(2).integers(0, 2, 4).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, fill_value=fill_value) + assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all() + assert result[3] == fill_value + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + + result = algos.take_nd(data, indexer, fill_value=fill_value) + assert (result[[0, 1, 2, 3]] == data[indexer]).all() + assert result.dtype == dtype + + def test_2d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + data = np.random.default_rng(2).integers(0, 2, (5, 3)).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all() + assert (result[3, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all() + assert (result[:, 3] == fill_value).all() + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all() + assert result.dtype == dtype + + def test_3d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + + data = np.random.default_rng(2).integers(0, 2, (5, 4, 3)).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all() + assert (result[3, :, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all() + assert (result[:, 3, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value) + assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all() + assert (result[:, :, 3] == fill_value).all() + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value) + assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all() + assert result.dtype == dtype + + def test_1d_other_dtypes(self): + arr = np.random.default_rng(2).standard_normal(10).astype(np.float32) + + indexer = [1, 2, 3, -1] + result = algos.take_nd(arr, indexer) + expected = arr.take(indexer) + expected[-1] = np.nan + tm.assert_almost_equal(result, expected) + + def test_2d_other_dtypes(self): + arr = np.random.default_rng(2).standard_normal((10, 5)).astype(np.float32) + + indexer = [1, 2, 3, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + expected = arr.take(indexer, axis=0) + expected[-1] = np.nan + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected[:, -1] = np.nan + tm.assert_almost_equal(result, expected) + + def test_1d_bool(self): + arr = np.array([0, 1, 0], dtype=bool) + + result = algos.take_nd(arr, [0, 2, 2, 1]) + expected = arr.take([0, 2, 2, 1]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, -1]) + assert result.dtype == np.object_ + + def test_2d_bool(self): + arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool) + + result = algos.take_nd(arr, [0, 2, 2, 1]) + expected = arr.take([0, 2, 2, 1], axis=0) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, 2, 1], axis=1) + expected = arr.take([0, 2, 2, 1], axis=1) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, -1]) + assert result.dtype == np.object_ + + def test_2d_float32(self): + arr = np.random.default_rng(2).standard_normal((4, 3)).astype(np.float32) + indexer = [0, 2, -1, 1, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + + expected = arr.take(indexer, axis=0) + expected[[2, 4], :] = np.nan + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected[:, [2, 4]] = np.nan + tm.assert_almost_equal(result, expected) + + def test_2d_datetime64(self): + # 2005/01/01 - 2006/01/01 + arr = ( + np.random.default_rng(2).integers(11_045_376, 11_360_736, (5, 3)) + * 100_000_000_000 + ) + arr = arr.view(dtype="datetime64[ns]") + indexer = [0, 2, -1, 1, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + expected = arr.take(indexer, axis=0) + expected.view(np.int64)[[2, 4], :] = iNaT + tm.assert_almost_equal(result, expected) + + result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1)) + expected = arr.take(indexer, axis=0) + expected[[2, 4], :] = datetime(2007, 1, 1) + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected.view(np.int64)[:, [2, 4]] = iNaT + tm.assert_almost_equal(result, expected) + + result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1)) + expected = arr.take(indexer, axis=1) + expected[:, [2, 4]] = datetime(2007, 1, 1) + tm.assert_almost_equal(result, expected) + + def test_take_axis_0(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1]) + expected = np.array([[0, 1, 2], [9, 10, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0) + expected = np.array([[0, 1, 2], [0, 0, 0]]) + tm.assert_numpy_array_equal(result, expected) + + def test_take_axis_1(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1], axis=1) + expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0) + expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]]) + tm.assert_numpy_array_equal(result, expected) + + # GH#26976 make sure we validate along the correct axis + with pytest.raises(IndexError, match="indices are out-of-bounds"): + algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0) + + def test_take_non_hashable_fill_value(self): + arr = np.array([1, 2, 3]) + indexer = np.array([1, -1]) + with pytest.raises(ValueError, match="fill_value must be a scalar"): + algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + + # with object dtype it is allowed + arr = np.array([1, 2, 3], dtype=object) + result = algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + expected = np.array([2, [1]], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +class TestExtensionTake: + # The take method found in pd.api.extensions + + def test_bounds_check_large(self): + arr = np.array([1, 2]) + + msg = "indices are out-of-bounds" + with pytest.raises(IndexError, match=msg): + algos.take(arr, [2, 3], allow_fill=True) + + msg = "index 2 is out of bounds for( axis 0 with)? size 2" + with pytest.raises(IndexError, match=msg): + algos.take(arr, [2, 3], allow_fill=False) + + def test_bounds_check_small(self): + arr = np.array([1, 2, 3], dtype=np.int64) + indexer = [0, -1, -2] + + msg = r"'indices' contains values less than allowed \(-2 < -1\)" + with pytest.raises(ValueError, match=msg): + algos.take(arr, indexer, allow_fill=True) + + result = algos.take(arr, indexer) + expected = np.array([1, 3, 2], dtype=np.int64) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("allow_fill", [True, False]) + def test_take_empty(self, allow_fill): + arr = np.array([], dtype=np.int64) + # empty take is ok + result = algos.take(arr, [], allow_fill=allow_fill) + tm.assert_numpy_array_equal(arr, result) + + msg = "|".join( + [ + "cannot do a non-empty take from an empty axes.", + "indices are out-of-bounds", + ] + ) + with pytest.raises(IndexError, match=msg): + algos.take(arr, [0], allow_fill=allow_fill) + + def test_take_na_empty(self): + result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0) + expected = np.array([0.0, 0.0]) + tm.assert_numpy_array_equal(result, expected) + + def test_take_coerces_list(self): + arr = [1, 2, 3] + msg = "take accepting non-standard inputs is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + result = algos.take(arr, [0, 0]) + expected = np.array([1, 1]) + tm.assert_numpy_array_equal(result, expected) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82b3aa56c653cd1241872c67e9d9016df04a6c5a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/__init__.py @@ -0,0 +1,29 @@ +def __getattr__(key: str): + # These imports need to be lazy to avoid circular import errors + if key == "hash_array": + from pandas.core.util.hashing import hash_array + + return hash_array + if key == "hash_pandas_object": + from pandas.core.util.hashing import hash_pandas_object + + return hash_pandas_object + if key == "Appender": + from pandas.util._decorators import Appender + + return Appender + if key == "Substitution": + from pandas.util._decorators import Substitution + + return Substitution + + if key == "cache_readonly": + from pandas.util._decorators import cache_readonly + + return cache_readonly + + raise AttributeError(f"module 'pandas.util' has no attribute '{key}'") + + +def capitalize_first_letter(s): + return s[:1].upper() + s[1:] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_decorators.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8189e72c427d75da0367532edfe04d0a7581e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_decorators.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +from functools import wraps +import inspect +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Any, + Callable, + cast, +) +import warnings + +from pandas._libs.properties import cache_readonly +from pandas._typing import ( + F, + T, +) +from pandas.util._exceptions import find_stack_level + +if TYPE_CHECKING: + from collections.abc import Mapping + + +def deprecate( + name: str, + alternative: Callable[..., Any], + version: str, + alt_name: str | None = None, + klass: type[Warning] | None = None, + stacklevel: int = 2, + msg: str | None = None, +) -> Callable[[F], F]: + """ + Return a new function that emits a deprecation warning on use. + + To use this method for a deprecated function, another function + `alternative` with the same signature must exist. The deprecated + function will emit a deprecation warning, and in the docstring + it will contain the deprecation directive with the provided version + so it can be detected for future removal. + + Parameters + ---------- + name : str + Name of function to deprecate. + alternative : func + Function to use instead. + version : str + Version of pandas in which the method has been deprecated. + alt_name : str, optional + Name to use in preference of alternative.__name__. + klass : Warning, default FutureWarning + stacklevel : int, default 2 + msg : str + The message to display in the warning. + Default is '{name} is deprecated. Use {alt_name} instead.' + """ + alt_name = alt_name or alternative.__name__ + klass = klass or FutureWarning + warning_msg = msg or f"{name} is deprecated, use {alt_name} instead." + + @wraps(alternative) + def wrapper(*args, **kwargs) -> Callable[..., Any]: + warnings.warn(warning_msg, klass, stacklevel=stacklevel) + return alternative(*args, **kwargs) + + # adding deprecated directive to the docstring + msg = msg or f"Use `{alt_name}` instead." + doc_error_msg = ( + "deprecate needs a correctly formatted docstring in " + "the target function (should have a one liner short " + "summary, and opening quotes should be in their own " + f"line). Found:\n{alternative.__doc__}" + ) + + # when python is running in optimized mode (i.e. `-OO`), docstrings are + # removed, so we check that a docstring with correct formatting is used + # but we allow empty docstrings + if alternative.__doc__: + if alternative.__doc__.count("\n") < 3: + raise AssertionError(doc_error_msg) + empty1, summary, empty2, doc_string = alternative.__doc__.split("\n", 3) + if empty1 or empty2 and not summary: + raise AssertionError(doc_error_msg) + wrapper.__doc__ = dedent( + f""" + {summary.strip()} + + .. deprecated:: {version} + {msg} + + {dedent(doc_string)}""" + ) + # error: Incompatible return value type (got "Callable[[VarArg(Any), KwArg(Any)], + # Callable[...,Any]]", expected "Callable[[F], F]") + return wrapper # type: ignore[return-value] + + +def deprecate_kwarg( + old_arg_name: str, + new_arg_name: str | None, + mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None, + stacklevel: int = 2, +) -> Callable[[F], F]: + """ + Decorator to deprecate a keyword argument of a function. + + Parameters + ---------- + old_arg_name : str + Name of argument in function to deprecate + new_arg_name : str or None + Name of preferred argument in function. Use None to raise warning that + ``old_arg_name`` keyword is deprecated. + mapping : dict or callable + If mapping is present, use it to translate old arguments to + new arguments. A callable must do its own value checking; + values not found in a dict will be forwarded unchanged. + + Examples + -------- + The following deprecates 'cols', using 'columns' instead + + >>> @deprecate_kwarg(old_arg_name='cols', new_arg_name='columns') + ... def f(columns=''): + ... print(columns) + ... + >>> f(columns='should work ok') + should work ok + + >>> f(cols='should raise warning') # doctest: +SKIP + FutureWarning: cols is deprecated, use columns instead + warnings.warn(msg, FutureWarning) + should raise warning + + >>> f(cols='should error', columns="can\'t pass do both") # doctest: +SKIP + TypeError: Can only specify 'cols' or 'columns', not both + + >>> @deprecate_kwarg('old', 'new', {'yes': True, 'no': False}) + ... def f(new=False): + ... print('yes!' if new else 'no!') + ... + >>> f(old='yes') # doctest: +SKIP + FutureWarning: old='yes' is deprecated, use new=True instead + warnings.warn(msg, FutureWarning) + yes! + + To raise a warning that a keyword will be removed entirely in the future + + >>> @deprecate_kwarg(old_arg_name='cols', new_arg_name=None) + ... def f(cols='', another_param=''): + ... print(cols) + ... + >>> f(cols='should raise warning') # doctest: +SKIP + FutureWarning: the 'cols' keyword is deprecated and will be removed in a + future version please takes steps to stop use of 'cols' + should raise warning + >>> f(another_param='should not raise warning') # doctest: +SKIP + should not raise warning + + >>> f(cols='should raise warning', another_param='') # doctest: +SKIP + FutureWarning: the 'cols' keyword is deprecated and will be removed in a + future version please takes steps to stop use of 'cols' + should raise warning + """ + if mapping is not None and not hasattr(mapping, "get") and not callable(mapping): + raise TypeError( + "mapping from old to new argument values must be dict or callable!" + ) + + def _deprecate_kwarg(func: F) -> F: + @wraps(func) + def wrapper(*args, **kwargs) -> Callable[..., Any]: + old_arg_value = kwargs.pop(old_arg_name, None) + + if old_arg_value is not None: + if new_arg_name is None: + msg = ( + f"the {repr(old_arg_name)} keyword is deprecated and " + "will be removed in a future version. Please take " + f"steps to stop the use of {repr(old_arg_name)}" + ) + warnings.warn(msg, FutureWarning, stacklevel=stacklevel) + kwargs[old_arg_name] = old_arg_value + return func(*args, **kwargs) + + elif mapping is not None: + if callable(mapping): + new_arg_value = mapping(old_arg_value) + else: + new_arg_value = mapping.get(old_arg_value, old_arg_value) + msg = ( + f"the {old_arg_name}={repr(old_arg_value)} keyword is " + "deprecated, use " + f"{new_arg_name}={repr(new_arg_value)} instead." + ) + else: + new_arg_value = old_arg_value + msg = ( + f"the {repr(old_arg_name)} keyword is deprecated, " + f"use {repr(new_arg_name)} instead." + ) + + warnings.warn(msg, FutureWarning, stacklevel=stacklevel) + if kwargs.get(new_arg_name) is not None: + msg = ( + f"Can only specify {repr(old_arg_name)} " + f"or {repr(new_arg_name)}, not both." + ) + raise TypeError(msg) + kwargs[new_arg_name] = new_arg_value + return func(*args, **kwargs) + + return cast(F, wrapper) + + return _deprecate_kwarg + + +def _format_argument_list(allow_args: list[str]) -> str: + """ + Convert the allow_args argument (either string or integer) of + `deprecate_nonkeyword_arguments` function to a string describing + it to be inserted into warning message. + + Parameters + ---------- + allowed_args : list, tuple or int + The `allowed_args` argument for `deprecate_nonkeyword_arguments`, + but None value is not allowed. + + Returns + ------- + str + The substring describing the argument list in best way to be + inserted to the warning message. + + Examples + -------- + `format_argument_list([])` -> '' + `format_argument_list(['a'])` -> "except for the arguments 'a'" + `format_argument_list(['a', 'b'])` -> "except for the arguments 'a' and 'b'" + `format_argument_list(['a', 'b', 'c'])` -> + "except for the arguments 'a', 'b' and 'c'" + """ + if "self" in allow_args: + allow_args.remove("self") + if not allow_args: + return "" + elif len(allow_args) == 1: + return f" except for the argument '{allow_args[0]}'" + else: + last = allow_args[-1] + args = ", ".join(["'" + x + "'" for x in allow_args[:-1]]) + return f" except for the arguments {args} and '{last}'" + + +def future_version_msg(version: str | None) -> str: + """Specify which version of pandas the deprecation will take place in.""" + if version is None: + return "In a future version of pandas" + else: + return f"Starting with pandas version {version}" + + +def deprecate_nonkeyword_arguments( + version: str | None, + allowed_args: list[str] | None = None, + name: str | None = None, +) -> Callable[[F], F]: + """ + Decorator to deprecate a use of non-keyword arguments of a function. + + Parameters + ---------- + version : str, optional + The version in which positional arguments will become + keyword-only. If None, then the warning message won't + specify any particular version. + + allowed_args : list, optional + In case of list, it must be the list of names of some + first arguments of the decorated functions that are + OK to be given as positional arguments. In case of None value, + defaults to list of all arguments not having the + default value. + + name : str, optional + The specific name of the function to show in the warning + message. If None, then the Qualified name of the function + is used. + """ + + def decorate(func): + old_sig = inspect.signature(func) + + if allowed_args is not None: + allow_args = allowed_args + else: + allow_args = [ + p.name + for p in old_sig.parameters.values() + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.default is p.empty + ] + + new_params = [ + p.replace(kind=p.KEYWORD_ONLY) + if ( + p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.name not in allow_args + ) + else p + for p in old_sig.parameters.values() + ] + new_params.sort(key=lambda p: p.kind) + new_sig = old_sig.replace(parameters=new_params) + + num_allow_args = len(allow_args) + msg = ( + f"{future_version_msg(version)} all arguments of " + f"{name or func.__qualname__}{{arguments}} will be keyword-only." + ) + + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > num_allow_args: + warnings.warn( + msg.format(arguments=_format_argument_list(allow_args)), + FutureWarning, + stacklevel=find_stack_level(), + ) + return func(*args, **kwargs) + + # error: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no + # attribute "__signature__" + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + return wrapper + + return decorate + + +def doc(*docstrings: None | str | Callable, **params) -> Callable[[F], F]: + """ + A decorator to take docstring templates, concatenate them and perform string + substitution on them. + + This decorator will add a variable "_docstring_components" to the wrapped + callable to keep track the original docstring template for potential usage. + If it should be consider as a template, it will be saved as a string. + Otherwise, it will be saved as callable, and later user __doc__ and dedent + to get docstring. + + Parameters + ---------- + *docstrings : None, str, or callable + The string / docstring / docstring template to be appended in order + after default docstring under callable. + **params + The string which would be used to format docstring template. + """ + + def decorator(decorated: F) -> F: + # collecting docstring and docstring templates + docstring_components: list[str | Callable] = [] + if decorated.__doc__: + docstring_components.append(dedent(decorated.__doc__)) + + for docstring in docstrings: + if docstring is None: + continue + if hasattr(docstring, "_docstring_components"): + docstring_components.extend( + docstring._docstring_components # pyright: ignore[reportGeneralTypeIssues] + ) + elif isinstance(docstring, str) or docstring.__doc__: + docstring_components.append(docstring) + + params_applied = [ + component.format(**params) + if isinstance(component, str) and len(params) > 0 + else component + for component in docstring_components + ] + + decorated.__doc__ = "".join( + [ + component + if isinstance(component, str) + else dedent(component.__doc__ or "") + for component in params_applied + ] + ) + + # error: "F" has no attribute "_docstring_components" + decorated._docstring_components = ( # type: ignore[attr-defined] + docstring_components + ) + return decorated + + return decorator + + +# Substitution and Appender are derived from matplotlib.docstring (1.1.0) +# module https://matplotlib.org/users/license.html + + +class Substitution: + """ + A decorator to take a function's docstring and perform string + substitution on it. + + This decorator should be robust even if func.__doc__ is None + (for example, if -OO was passed to the interpreter) + + Usage: construct a docstring.Substitution with a sequence or + dictionary suitable for performing substitution; then + decorate a suitable function with the constructed object. e.g. + + sub_author_name = Substitution(author='Jason') + + @sub_author_name + def some_function(x): + "%(author)s wrote this function" + + # note that some_function.__doc__ is now "Jason wrote this function" + + One can also use positional arguments. + + sub_first_last_names = Substitution('Edgar Allen', 'Poe') + + @sub_first_last_names + def some_function(x): + "%s %s wrote the Raven" + """ + + def __init__(self, *args, **kwargs) -> None: + if args and kwargs: + raise AssertionError("Only positional or keyword args are allowed") + + self.params = args or kwargs + + def __call__(self, func: F) -> F: + func.__doc__ = func.__doc__ and func.__doc__ % self.params + return func + + def update(self, *args, **kwargs) -> None: + """ + Update self.params with supplied args. + """ + if isinstance(self.params, dict): + self.params.update(*args, **kwargs) + + +class Appender: + """ + A function decorator that will append an addendum to the docstring + of the target function. + + This decorator should be robust even if func.__doc__ is None + (for example, if -OO was passed to the interpreter). + + Usage: construct a docstring.Appender with a string to be joined to + the original docstring. An optional 'join' parameter may be supplied + which will be used to join the docstring and addendum. e.g. + + add_copyright = Appender("Copyright (c) 2009", join='\n') + + @add_copyright + def my_dog(has='fleas'): + "This docstring will have a copyright below" + pass + """ + + addendum: str | None + + def __init__(self, addendum: str | None, join: str = "", indents: int = 0) -> None: + if indents > 0: + self.addendum = indent(addendum, indents=indents) + else: + self.addendum = addendum + self.join = join + + def __call__(self, func: T) -> T: + func.__doc__ = func.__doc__ if func.__doc__ else "" + self.addendum = self.addendum if self.addendum else "" + docitems = [func.__doc__, self.addendum] + func.__doc__ = dedent(self.join.join(docitems)) + return func + + +def indent(text: str | None, indents: int = 1) -> str: + if not text or not isinstance(text, str): + return "" + jointext = "".join(["\n"] + [" "] * indents) + return jointext.join(text.split("\n")) + + +__all__ = [ + "Appender", + "cache_readonly", + "deprecate", + "deprecate_kwarg", + "deprecate_nonkeyword_arguments", + "doc", + "future_version_msg", + "Substitution", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_doctools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_doctools.py new file mode 100644 index 0000000000000000000000000000000000000000..12619abf4baaf336dfd3d5ae78a9bc2133f310c0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_doctools.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +import pandas as pd + +if TYPE_CHECKING: + from collections.abc import Iterable + + +class TablePlotter: + """ + Layout some DataFrames in vertical/horizontal layout for explanation. + Used in merging.rst + """ + + def __init__( + self, + cell_width: float = 0.37, + cell_height: float = 0.25, + font_size: float = 7.5, + ) -> None: + self.cell_width = cell_width + self.cell_height = cell_height + self.font_size = font_size + + def _shape(self, df: pd.DataFrame) -> tuple[int, int]: + """ + Calculate table shape considering index levels. + """ + row, col = df.shape + return row + df.columns.nlevels, col + df.index.nlevels + + def _get_cells(self, left, right, vertical) -> tuple[int, int]: + """ + Calculate appropriate figure size based on left and right data. + """ + if vertical: + # calculate required number of cells + vcells = max(sum(self._shape(df)[0] for df in left), self._shape(right)[0]) + hcells = max(self._shape(df)[1] for df in left) + self._shape(right)[1] + else: + vcells = max([self._shape(df)[0] for df in left] + [self._shape(right)[0]]) + hcells = sum([self._shape(df)[1] for df in left] + [self._shape(right)[1]]) + return hcells, vcells + + def plot(self, left, right, labels: Iterable[str] = (), vertical: bool = True): + """ + Plot left / right DataFrames in specified layout. + + Parameters + ---------- + left : list of DataFrames before operation is applied + right : DataFrame of operation result + labels : list of str to be drawn as titles of left DataFrames + vertical : bool, default True + If True, use vertical layout. If False, use horizontal layout. + """ + from matplotlib import gridspec + import matplotlib.pyplot as plt + + if not isinstance(left, list): + left = [left] + left = [self._conv(df) for df in left] + right = self._conv(right) + + hcells, vcells = self._get_cells(left, right, vertical) + + if vertical: + figsize = self.cell_width * hcells, self.cell_height * vcells + else: + # include margin for titles + figsize = self.cell_width * hcells, self.cell_height * vcells + fig = plt.figure(figsize=figsize) + + if vertical: + gs = gridspec.GridSpec(len(left), hcells) + # left + max_left_cols = max(self._shape(df)[1] for df in left) + max_left_rows = max(self._shape(df)[0] for df in left) + for i, (_left, _label) in enumerate(zip(left, labels)): + ax = fig.add_subplot(gs[i, 0:max_left_cols]) + self._make_table(ax, _left, title=_label, height=1.0 / max_left_rows) + # right + ax = plt.subplot(gs[:, max_left_cols:]) + self._make_table(ax, right, title="Result", height=1.05 / vcells) + fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95) + else: + max_rows = max(self._shape(df)[0] for df in left + [right]) + height = 1.0 / np.max(max_rows) + gs = gridspec.GridSpec(1, hcells) + # left + i = 0 + for df, _label in zip(left, labels): + sp = self._shape(df) + ax = fig.add_subplot(gs[0, i : i + sp[1]]) + self._make_table(ax, df, title=_label, height=height) + i += sp[1] + # right + ax = plt.subplot(gs[0, i:]) + self._make_table(ax, right, title="Result", height=height) + fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95) + + return fig + + def _conv(self, data): + """ + Convert each input to appropriate for table outplot. + """ + if isinstance(data, pd.Series): + if data.name is None: + data = data.to_frame(name="") + else: + data = data.to_frame() + data = data.fillna("NaN") + return data + + def _insert_index(self, data): + # insert is destructive + data = data.copy() + idx_nlevels = data.index.nlevels + if idx_nlevels == 1: + data.insert(0, "Index", data.index) + else: + for i in range(idx_nlevels): + data.insert(i, f"Index{i}", data.index._get_level_values(i)) + + col_nlevels = data.columns.nlevels + if col_nlevels > 1: + col = data.columns._get_level_values(0) + values = [ + data.columns._get_level_values(i)._values for i in range(1, col_nlevels) + ] + col_df = pd.DataFrame(values) + data.columns = col_df.columns + data = pd.concat([col_df, data]) + data.columns = col + return data + + def _make_table(self, ax, df, title: str, height: float | None = None) -> None: + if df is None: + ax.set_visible(False) + return + + from pandas import plotting + + idx_nlevels = df.index.nlevels + col_nlevels = df.columns.nlevels + # must be convert here to get index levels for colorization + df = self._insert_index(df) + tb = plotting.table(ax, df, loc=9) + tb.set_fontsize(self.font_size) + + if height is None: + height = 1.0 / (len(df) + 1) + + props = tb.properties() + for (r, c), cell in props["celld"].items(): + if c == -1: + cell.set_visible(False) + elif r < col_nlevels and c < idx_nlevels: + cell.set_visible(False) + elif r < col_nlevels or c < idx_nlevels: + cell.set_facecolor("#AAAAAA") + cell.set_height(height) + + ax.set_title(title, size=self.font_size) + ax.axis("off") + + +def main() -> None: + import matplotlib.pyplot as plt + + p = TablePlotter() + + df1 = pd.DataFrame({"A": [10, 11, 12], "B": [20, 21, 22], "C": [30, 31, 32]}) + df2 = pd.DataFrame({"A": [10, 12], "C": [30, 32]}) + + p.plot([df1, df2], pd.concat([df1, df2]), labels=["df1", "df2"], vertical=True) + plt.show() + + df3 = pd.DataFrame({"X": [10, 12], "Z": [30, 32]}) + + p.plot( + [df1, df3], pd.concat([df1, df3], axis=1), labels=["df1", "df2"], vertical=False + ) + plt.show() + + idx = pd.MultiIndex.from_tuples( + [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")] + ) + column = pd.MultiIndex.from_tuples([(1, "A"), (1, "B")]) + df3 = pd.DataFrame({"v1": [1, 2, 3, 4, 5, 6], "v2": [5, 6, 7, 8, 9, 10]}, index=idx) + df3.columns = column + p.plot(df3, df3, labels=["df3"]) + plt.show() + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_exceptions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5f50838d373154868ff7414775763a1c66853c65 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_exceptions.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import contextlib +import inspect +import os +import re +from typing import TYPE_CHECKING +import warnings + +if TYPE_CHECKING: + from collections.abc import Generator + from types import FrameType + + +@contextlib.contextmanager +def rewrite_exception(old_name: str, new_name: str) -> Generator[None, None, None]: + """ + Rewrite the message of an exception. + """ + try: + yield + except Exception as err: + if not err.args: + raise + msg = str(err.args[0]) + msg = msg.replace(old_name, new_name) + args: tuple[str, ...] = (msg,) + if len(err.args) > 1: + args = args + err.args[1:] + err.args = args + raise + + +def find_stack_level() -> int: + """ + Find the first place in the stack that is not inside pandas + (tests notwithstanding). + """ + + import pandas as pd + + pkg_dir = os.path.dirname(pd.__file__) + test_dir = os.path.join(pkg_dir, "tests") + + # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow + frame: FrameType | None = inspect.currentframe() + try: + n = 0 + while frame: + filename = inspect.getfile(frame) + if filename.startswith(pkg_dir) and not filename.startswith(test_dir): + frame = frame.f_back + n += 1 + else: + break + finally: + # See note in + # https://docs.python.org/3/library/inspect.html#inspect.Traceback + del frame + return n + + +@contextlib.contextmanager +def rewrite_warning( + target_message: str, + target_category: type[Warning], + new_message: str, + new_category: type[Warning] | None = None, +) -> Generator[None, None, None]: + """ + Rewrite the message of a warning. + + Parameters + ---------- + target_message : str + Warning message to match. + target_category : Warning + Warning type to match. + new_message : str + New warning message to emit. + new_category : Warning or None, default None + New warning type to emit. When None, will be the same as target_category. + """ + if new_category is None: + new_category = target_category + with warnings.catch_warnings(record=True) as record: + yield + if len(record) > 0: + match = re.compile(target_message) + for warning in record: + if warning.category is target_category and re.search( + match, str(warning.message) + ): + category = new_category + message: Warning | str = new_message + else: + category, message = warning.category, warning.message + warnings.warn_explicit( + message=message, + category=category, + filename=warning.filename, + lineno=warning.lineno, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_print_versions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_print_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..4ede5627c28b9a3eaf97f09f6a28642523ce5833 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_print_versions.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import codecs +import json +import locale +import os +import platform +import struct +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pandas._typing import JSONSerializable + +from pandas.compat._optional import ( + VERSIONS, + get_version, + import_optional_dependency, +) + + +def _get_commit_hash() -> str | None: + """ + Use vendored versioneer code to get git hash, which handles + git worktree correctly. + """ + try: + from pandas._version_meson import ( # pyright: ignore [reportMissingImports] + __git_version__, + ) + + return __git_version__ + except ImportError: + from pandas._version import get_versions + + versions = get_versions() + return versions["full-revisionid"] + + +def _get_sys_info() -> dict[str, JSONSerializable]: + """ + Returns system information as a JSON serializable dictionary. + """ + uname_result = platform.uname() + language_code, encoding = locale.getlocale() + return { + "commit": _get_commit_hash(), + "python": platform.python_version(), + "python-bits": struct.calcsize("P") * 8, + "OS": uname_result.system, + "OS-release": uname_result.release, + "Version": uname_result.version, + "machine": uname_result.machine, + "processor": uname_result.processor, + "byteorder": sys.byteorder, + "LC_ALL": os.environ.get("LC_ALL"), + "LANG": os.environ.get("LANG"), + "LOCALE": {"language-code": language_code, "encoding": encoding}, + } + + +def _get_dependency_info() -> dict[str, JSONSerializable]: + """ + Returns dependency information as a JSON serializable dictionary. + """ + deps = [ + "pandas", + # required + "numpy", + "pytz", + "dateutil", + # install / build, + "pip", + "Cython", + # docs + "sphinx", + # Other, not imported. + "IPython", + ] + # Optional dependencies + deps.extend(list(VERSIONS)) + + result: dict[str, JSONSerializable] = {} + for modname in deps: + try: + mod = import_optional_dependency(modname, errors="ignore") + except Exception: + # Dependency conflicts may cause a non ImportError + result[modname] = "N/A" + else: + result[modname] = get_version(mod) if mod else None + return result + + +def show_versions(as_json: str | bool = False) -> None: + """ + Provide useful information, important for bug reports. + + It comprises info about hosting operation system, pandas version, + and versions of other installed relative packages. + + Parameters + ---------- + as_json : str or bool, default False + * If False, outputs info in a human readable form to the console. + * If str, it will be considered as a path to a file. + Info will be written to that file in JSON format. + * If True, outputs info in JSON format to the console. + + Examples + -------- + >>> pd.show_versions() # doctest: +SKIP + Your output may look something like this: + INSTALLED VERSIONS + ------------------ + commit : 37ea63d540fd27274cad6585082c91b1283f963d + python : 3.10.6.final.0 + python-bits : 64 + OS : Linux + OS-release : 5.10.102.1-microsoft-standard-WSL2 + Version : #1 SMP Wed Mar 2 00:30:59 UTC 2022 + machine : x86_64 + processor : x86_64 + byteorder : little + LC_ALL : None + LANG : en_GB.UTF-8 + LOCALE : en_GB.UTF-8 + pandas : 2.0.1 + numpy : 1.24.3 + ... + """ + sys_info = _get_sys_info() + deps = _get_dependency_info() + + if as_json: + j = {"system": sys_info, "dependencies": deps} + + if as_json is True: + sys.stdout.writelines(json.dumps(j, indent=2)) + else: + assert isinstance(as_json, str) # needed for mypy + with codecs.open(as_json, "wb", encoding="utf8") as f: + json.dump(j, f, indent=2) + + else: + assert isinstance(sys_info["LOCALE"], dict) # needed for mypy + language_code = sys_info["LOCALE"]["language-code"] + encoding = sys_info["LOCALE"]["encoding"] + sys_info["LOCALE"] = f"{language_code}.{encoding}" + + maxlen = max(len(x) for x in deps) + print("\nINSTALLED VERSIONS") + print("------------------") + for k, v in sys_info.items(): + print(f"{k:<{maxlen}}: {v}") + print("") + for k, v in deps.items(): + print(f"{k:<{maxlen}}: {v}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_test_decorators.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_test_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1912bce856dd2694447d820ea2c5124be9c1a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_test_decorators.py @@ -0,0 +1,173 @@ +""" +This module provides decorator functions which can be applied to test objects +in order to skip those objects when certain conditions occur. A sample use case +is to detect if the platform is missing ``matplotlib``. If so, any test objects +which require ``matplotlib`` and decorated with ``@td.skip_if_no("matplotlib")`` +will be skipped by ``pytest`` during the execution of the test suite. + +To illustrate, after importing this module: + +import pandas.util._test_decorators as td + +The decorators can be applied to classes: + +@td.skip_if_no("package") +class Foo: + ... + +Or individual functions: + +@td.skip_if_no("package") +def test_foo(): + ... + +For more information, refer to the ``pytest`` documentation on ``skipif``. +""" +from __future__ import annotations + +import locale +from typing import ( + TYPE_CHECKING, + Callable, +) + +import pytest + +from pandas._config import get_option + +if TYPE_CHECKING: + from pandas._typing import F + +from pandas._config.config import _get_option + +from pandas.compat import ( + IS64, + is_platform_windows, +) +from pandas.compat._optional import import_optional_dependency + + +def skip_if_installed(package: str) -> pytest.MarkDecorator: + """ + Skip a test if a package is installed. + + Parameters + ---------- + package : str + The name of the package. + + Returns + ------- + pytest.MarkDecorator + a pytest.mark.skipif to use as either a test decorator or a + parametrization mark. + """ + return pytest.mark.skipif( + bool(import_optional_dependency(package, errors="ignore")), + reason=f"Skipping because {package} is installed.", + ) + + +def skip_if_no(package: str, min_version: str | None = None) -> pytest.MarkDecorator: + """ + Generic function to help skip tests when required packages are not + present on the testing system. + + This function returns a pytest mark with a skip condition that will be + evaluated during test collection. An attempt will be made to import the + specified ``package`` and optionally ensure it meets the ``min_version`` + + The mark can be used as either a decorator for a test class or to be + applied to parameters in pytest.mark.parametrize calls or parametrized + fixtures. Use pytest.importorskip if an imported moduled is later needed + or for test functions. + + If the import and version check are unsuccessful, then the test function + (or test case when used in conjunction with parametrization) will be + skipped. + + Parameters + ---------- + package: str + The name of the required package. + min_version: str or None, default None + Optional minimum version of the package. + + Returns + ------- + pytest.MarkDecorator + a pytest.mark.skipif to use as either a test decorator or a + parametrization mark. + """ + msg = f"Could not import '{package}'" + if min_version: + msg += f" satisfying a min_version of {min_version}" + return pytest.mark.skipif( + not bool( + import_optional_dependency( + package, errors="ignore", min_version=min_version + ) + ), + reason=msg, + ) + + +skip_if_32bit = pytest.mark.skipif(not IS64, reason="skipping for 32 bit") +skip_if_windows = pytest.mark.skipif(is_platform_windows(), reason="Running on Windows") +skip_if_not_us_locale = pytest.mark.skipif( + locale.getlocale()[0] != "en_US", + reason=f"Set local {locale.getlocale()[0]} is not en_US", +) + + +def parametrize_fixture_doc(*args) -> Callable[[F], F]: + """ + Intended for use as a decorator for parametrized fixture, + this function will wrap the decorated function with a pytest + ``parametrize_fixture_doc`` mark. That mark will format + initial fixture docstring by replacing placeholders {0}, {1} etc + with parameters passed as arguments. + + Parameters + ---------- + args: iterable + Positional arguments for docstring. + + Returns + ------- + function + The decorated function wrapped within a pytest + ``parametrize_fixture_doc`` mark + """ + + def documented_fixture(fixture): + fixture.__doc__ = fixture.__doc__.format(*args) + return fixture + + return documented_fixture + + +def mark_array_manager_not_yet_implemented(request) -> None: + mark = pytest.mark.xfail(reason="Not yet implemented for ArrayManager") + request.applymarker(mark) + + +skip_array_manager_not_yet_implemented = pytest.mark.xfail( + _get_option("mode.data_manager", silent=True) == "array", + reason="Not yet implemented for ArrayManager", +) + +skip_array_manager_invalid_test = pytest.mark.skipif( + _get_option("mode.data_manager", silent=True) == "array", + reason="Test that relies on BlockManager internals or specific behaviour", +) + +skip_copy_on_write_not_yet_implemented = pytest.mark.xfail( + get_option("mode.copy_on_write") is True, + reason="Not yet implemented/adapted for Copy-on-Write mode", +) + +skip_copy_on_write_invalid_test = pytest.mark.skipif( + get_option("mode.copy_on_write") is True, + reason="Test not valid for Copy-on-Write mode", +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_tester.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_tester.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfddef7ddff87275ebf31eb7ec10e65d26f8668 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_tester.py @@ -0,0 +1,53 @@ +""" +Entrypoint for testing from the top-level namespace. +""" +from __future__ import annotations + +import os +import sys + +from pandas.compat._optional import import_optional_dependency + +PKG = os.path.dirname(os.path.dirname(__file__)) + + +def test(extra_args: list[str] | None = None, run_doctests: bool = False) -> None: + """ + Run the pandas test suite using pytest. + + By default, runs with the marks -m "not slow and not network and not db" + + Parameters + ---------- + extra_args : list[str], default None + Extra marks to run the tests. + run_doctests : bool, default False + Whether to only run the Python and Cython doctests. If you would like to run + both doctests/regular tests, just append "--doctest-modules"/"--doctest-cython" + to extra_args. + + Examples + -------- + >>> pd.test() # doctest: +SKIP + running: pytest... + """ + pytest = import_optional_dependency("pytest") + import_optional_dependency("hypothesis") + cmd = ["-m not slow and not network and not db"] + if extra_args: + if not isinstance(extra_args, list): + extra_args = [extra_args] + cmd = extra_args + if run_doctests: + cmd = [ + "--doctest-modules", + "--doctest-cython", + f"--ignore={os.path.join(PKG, 'tests')}", + ] + cmd += [PKG] + joined = " ".join(cmd) + print(f"running: pytest {joined}") + sys.exit(pytest.main(cmd)) + + +__all__ = ["test"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_validators.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_validators.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0b4d549f49ea972d50c97986c60be64c021c3c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pandas/util/_validators.py @@ -0,0 +1,456 @@ +""" +Module that contains many useful utilities +for validating data or function arguments +""" +from __future__ import annotations + +from collections.abc import ( + Iterable, + Sequence, +) +from typing import ( + TypeVar, + overload, +) + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.common import ( + is_bool, + is_integer, +) + +BoolishT = TypeVar("BoolishT", bool, int) +BoolishNoneT = TypeVar("BoolishNoneT", bool, int, None) + + +def _check_arg_length(fname, args, max_fname_arg_count, compat_args) -> None: + """ + Checks whether 'args' has length of at most 'compat_args'. Raises + a TypeError if that is not the case, similar to in Python when a + function is called with too many arguments. + """ + if max_fname_arg_count < 0: + raise ValueError("'max_fname_arg_count' must be non-negative") + + if len(args) > len(compat_args): + max_arg_count = len(compat_args) + max_fname_arg_count + actual_arg_count = len(args) + max_fname_arg_count + argument = "argument" if max_arg_count == 1 else "arguments" + + raise TypeError( + f"{fname}() takes at most {max_arg_count} {argument} " + f"({actual_arg_count} given)" + ) + + +def _check_for_default_values(fname, arg_val_dict, compat_args) -> None: + """ + Check that the keys in `arg_val_dict` are mapped to their + default values as specified in `compat_args`. + + Note that this function is to be called only when it has been + checked that arg_val_dict.keys() is a subset of compat_args + """ + for key in arg_val_dict: + # try checking equality directly with '=' operator, + # as comparison may have been overridden for the left + # hand object + try: + v1 = arg_val_dict[key] + v2 = compat_args[key] + + # check for None-ness otherwise we could end up + # comparing a numpy array vs None + if (v1 is not None and v2 is None) or (v1 is None and v2 is not None): + match = False + else: + match = v1 == v2 + + if not is_bool(match): + raise ValueError("'match' is not a boolean") + + # could not compare them directly, so try comparison + # using the 'is' operator + except ValueError: + match = arg_val_dict[key] is compat_args[key] + + if not match: + raise ValueError( + f"the '{key}' parameter is not supported in " + f"the pandas implementation of {fname}()" + ) + + +def validate_args(fname, args, max_fname_arg_count, compat_args) -> None: + """ + Checks whether the length of the `*args` argument passed into a function + has at most `len(compat_args)` arguments and whether or not all of these + elements in `args` are set to their default values. + + Parameters + ---------- + fname : str + The name of the function being passed the `*args` parameter + args : tuple + The `*args` parameter passed into a function + max_fname_arg_count : int + The maximum number of arguments that the function `fname` + can accept, excluding those in `args`. Used for displaying + appropriate error messages. Must be non-negative. + compat_args : dict + A dictionary of keys and their associated default values. + In order to accommodate buggy behaviour in some versions of `numpy`, + where a signature displayed keyword arguments but then passed those + arguments **positionally** internally when calling downstream + implementations, a dict ensures that the original + order of the keyword arguments is enforced. + + Raises + ------ + TypeError + If `args` contains more values than there are `compat_args` + ValueError + If `args` contains values that do not correspond to those + of the default values specified in `compat_args` + """ + _check_arg_length(fname, args, max_fname_arg_count, compat_args) + + # We do this so that we can provide a more informative + # error message about the parameters that we are not + # supporting in the pandas implementation of 'fname' + kwargs = dict(zip(compat_args, args)) + _check_for_default_values(fname, kwargs, compat_args) + + +def _check_for_invalid_keys(fname, kwargs, compat_args) -> None: + """ + Checks whether 'kwargs' contains any keys that are not + in 'compat_args' and raises a TypeError if there is one. + """ + # set(dict) --> set of the dictionary's keys + diff = set(kwargs) - set(compat_args) + + if diff: + bad_arg = next(iter(diff)) + raise TypeError(f"{fname}() got an unexpected keyword argument '{bad_arg}'") + + +def validate_kwargs(fname, kwargs, compat_args) -> None: + """ + Checks whether parameters passed to the **kwargs argument in a + function `fname` are valid parameters as specified in `*compat_args` + and whether or not they are set to their default values. + + Parameters + ---------- + fname : str + The name of the function being passed the `**kwargs` parameter + kwargs : dict + The `**kwargs` parameter passed into `fname` + compat_args: dict + A dictionary of keys that `kwargs` is allowed to have and their + associated default values + + Raises + ------ + TypeError if `kwargs` contains keys not in `compat_args` + ValueError if `kwargs` contains keys in `compat_args` that do not + map to the default values specified in `compat_args` + """ + kwds = kwargs.copy() + _check_for_invalid_keys(fname, kwargs, compat_args) + _check_for_default_values(fname, kwds, compat_args) + + +def validate_args_and_kwargs( + fname, args, kwargs, max_fname_arg_count, compat_args +) -> None: + """ + Checks whether parameters passed to the *args and **kwargs argument in a + function `fname` are valid parameters as specified in `*compat_args` + and whether or not they are set to their default values. + + Parameters + ---------- + fname: str + The name of the function being passed the `**kwargs` parameter + args: tuple + The `*args` parameter passed into a function + kwargs: dict + The `**kwargs` parameter passed into `fname` + max_fname_arg_count: int + The minimum number of arguments that the function `fname` + requires, excluding those in `args`. Used for displaying + appropriate error messages. Must be non-negative. + compat_args: dict + A dictionary of keys that `kwargs` is allowed to + have and their associated default values. + + Raises + ------ + TypeError if `args` contains more values than there are + `compat_args` OR `kwargs` contains keys not in `compat_args` + ValueError if `args` contains values not at the default value (`None`) + `kwargs` contains keys in `compat_args` that do not map to the default + value as specified in `compat_args` + + See Also + -------- + validate_args : Purely args validation. + validate_kwargs : Purely kwargs validation. + + """ + # Check that the total number of arguments passed in (i.e. + # args and kwargs) does not exceed the length of compat_args + _check_arg_length( + fname, args + tuple(kwargs.values()), max_fname_arg_count, compat_args + ) + + # Check there is no overlap with the positional and keyword + # arguments, similar to what is done in actual Python functions + args_dict = dict(zip(compat_args, args)) + + for key in args_dict: + if key in kwargs: + raise TypeError( + f"{fname}() got multiple values for keyword argument '{key}'" + ) + + kwargs.update(args_dict) + validate_kwargs(fname, kwargs, compat_args) + + +def validate_bool_kwarg( + value: BoolishNoneT, + arg_name: str, + none_allowed: bool = True, + int_allowed: bool = False, +) -> BoolishNoneT: + """ + Ensure that argument passed in arg_name can be interpreted as boolean. + + Parameters + ---------- + value : bool + Value to be validated. + arg_name : str + Name of the argument. To be reflected in the error message. + none_allowed : bool, default True + Whether to consider None to be a valid boolean. + int_allowed : bool, default False + Whether to consider integer value to be a valid boolean. + + Returns + ------- + value + The same value as input. + + Raises + ------ + ValueError + If the value is not a valid boolean. + """ + good_value = is_bool(value) + if none_allowed: + good_value = good_value or (value is None) + + if int_allowed: + good_value = good_value or isinstance(value, int) + + if not good_value: + raise ValueError( + f'For argument "{arg_name}" expected type bool, received ' + f"type {type(value).__name__}." + ) + return value # pyright: ignore[reportGeneralTypeIssues] + + +def validate_fillna_kwargs(value, method, validate_scalar_dict_value: bool = True): + """ + Validate the keyword arguments to 'fillna'. + + This checks that exactly one of 'value' and 'method' is specified. + If 'method' is specified, this validates that it's a valid method. + + Parameters + ---------- + value, method : object + The 'value' and 'method' keyword arguments for 'fillna'. + validate_scalar_dict_value : bool, default True + Whether to validate that 'value' is a scalar or dict. Specifically, + validate that it is not a list or tuple. + + Returns + ------- + value, method : object + """ + from pandas.core.missing import clean_fill_method + + if value is None and method is None: + raise ValueError("Must specify a fill 'value' or 'method'.") + if value is None and method is not None: + method = clean_fill_method(method) + + elif value is not None and method is None: + if validate_scalar_dict_value and isinstance(value, (list, tuple)): + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + f'you passed a "{type(value).__name__}"' + ) + + elif value is not None and method is not None: + raise ValueError("Cannot specify both 'value' and 'method'.") + + return value, method + + +def validate_percentile(q: float | Iterable[float]) -> np.ndarray: + """ + Validate percentiles (used by describe and quantile). + + This function checks if the given float or iterable of floats is a valid percentile + otherwise raises a ValueError. + + Parameters + ---------- + q: float or iterable of floats + A single percentile or an iterable of percentiles. + + Returns + ------- + ndarray + An ndarray of the percentiles if valid. + + Raises + ------ + ValueError if percentiles are not in given interval([0, 1]). + """ + q_arr = np.asarray(q) + # Don't change this to an f-string. The string formatting + # is too expensive for cases where we don't need it. + msg = "percentiles should all be in the interval [0, 1]" + if q_arr.ndim == 0: + if not 0 <= q_arr <= 1: + raise ValueError(msg) + else: + if not all(0 <= qs <= 1 for qs in q_arr): + raise ValueError(msg) + return q_arr + + +@overload +def validate_ascending(ascending: BoolishT) -> BoolishT: + ... + + +@overload +def validate_ascending(ascending: Sequence[BoolishT]) -> list[BoolishT]: + ... + + +def validate_ascending( + ascending: bool | int | Sequence[BoolishT], +) -> bool | int | list[BoolishT]: + """Validate ``ascending`` kwargs for ``sort_index`` method.""" + kwargs = {"none_allowed": False, "int_allowed": True} + if not isinstance(ascending, Sequence): + return validate_bool_kwarg(ascending, "ascending", **kwargs) + + return [validate_bool_kwarg(item, "ascending", **kwargs) for item in ascending] + + +def validate_endpoints(closed: str | None) -> tuple[bool, bool]: + """ + Check that the `closed` argument is among [None, "left", "right"] + + Parameters + ---------- + closed : {None, "left", "right"} + + Returns + ------- + left_closed : bool + right_closed : bool + + Raises + ------ + ValueError : if argument is not among valid values + """ + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == "left": + left_closed = True + elif closed == "right": + right_closed = True + else: + raise ValueError("Closed has to be either 'left', 'right' or None") + + return left_closed, right_closed + + +def validate_inclusive(inclusive: str | None) -> tuple[bool, bool]: + """ + Check that the `inclusive` argument is among {"both", "neither", "left", "right"}. + + Parameters + ---------- + inclusive : {"both", "neither", "left", "right"} + + Returns + ------- + left_right_inclusive : tuple[bool, bool] + + Raises + ------ + ValueError : if argument is not among valid values + """ + left_right_inclusive: tuple[bool, bool] | None = None + + if isinstance(inclusive, str): + left_right_inclusive = { + "both": (True, True), + "left": (True, False), + "right": (False, True), + "neither": (False, False), + }.get(inclusive) + + if left_right_inclusive is None: + raise ValueError( + "Inclusive has to be either 'both', 'neither', 'left' or 'right'" + ) + + return left_right_inclusive + + +def validate_insert_loc(loc: int, length: int) -> int: + """ + Check that we have an integer between -length and length, inclusive. + + Standardize negative loc to within [0, length]. + + The exceptions we raise on failure match np.insert. + """ + if not is_integer(loc): + raise TypeError(f"loc must be an integer between -{length} and {length}") + + if loc < 0: + loc += length + if not 0 <= loc <= length: + raise IndexError(f"loc must be an integer between -{length} and {length}") + return loc # pyright: ignore[reportGeneralTypeIssues] + + +def check_dtype_backend(dtype_backend) -> None: + if dtype_backend is not lib.no_default: + if dtype_backend not in ["numpy_nullable", "pyarrow"]: + raise ValueError( + f"dtype_backend {dtype_backend} is invalid, only 'numpy_nullable' and " + f"'pyarrow' are allowed.", + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2a3a399684aa3ab77fb3b2f7d62ba99a2889182 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42d8e005752111af273171957ba059f8ac8fa430 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897acad2dad13af87920284488e4fe9690f0887b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..834dbce32f10bfb339fd2182a2455b43914441c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import dataclasses +import glob +import inspect +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, + export_case, + ExportCase, +) + + +def _collect_examples(): + case_names = glob.glob(join(dirname(__file__), "*.py")) + case_names = [ + basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py") + ] + + case_fields = {f.name for f in dataclasses.fields(ExportCase)} + for case_name in case_names: + case = __import__(case_name, globals(), locals(), [], 1) + variables = [name for name in dir(case) if name in case_fields] + export_case(**{v: getattr(case, v) for v in variables})(case.model) + +_collect_examples() + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..931ce7f7a50fc5a175101ac57c424c88cf31b54c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +import torch._dynamo as torchdynamo + + +class AssumeConstantResult(torch.nn.Module): + """ + Applying `assume_constant_result` decorator to burn make non-tracable code as constant. + """ + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"torch.escape-hatch"} +model = AssumeConstantResult() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..efd645d13a7d5a13dc69d9ab3593772520b726c0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, x): + return x.clone() + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return grad_output + 1 + +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend to + use `allow_in_graph` to mitigate this problem. + """ + + def forward(self, x): + return MyAutogradFunction.apply(x) + +example_args = (torch.randn(3, 2),) +model = AutogradFunction() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..f701f54d4f4ea1cb5816292cd60bb4df3d03c5e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. + """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) + +example_args = (torch.randn(3, 4),) +model = ClassMethod() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..22600cc504348d1d261b0ea2b9ed2d57af76b0a3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchClassMethod() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b28ceeddc7956d136a8cf786c283344731d3e7ac --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNestedFunction() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..50d0ec87a690d063cb0e841fc057a6ae95c369fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. + ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) + +example_args = (torch.randn(6),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNonlocalVariables() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..183180ab4fc825385170fea2bec6af184374a67e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. + """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) + +example_args = (torch.tensor(True), torch.randn(3, 2)) +tags = {"torch.cond", "python.closure"} +model = CondClosedOverVariable() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..60a75d24639cdac991298e99acf96b8eadbff442 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +y = torch.randn(2) +dim0_x = Dim("dim0_x") + +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) + +example_args = (x, y) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +extra_inputs = (torch.randn(2, 2), torch.randn(2)) +dynamic_shapes = {"x": {0: dim0_x}, "y": None} +model = CondOperands() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..68bb8850bba909a0c6546c8f12a1a3fa1bdc70d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) + +example_args = (torch.randn(6, 4, 3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondPredicate() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..934746aaf6739de7a37077d8ec3c2776586a5657 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check APIs. + """ + + def forward(self, x): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + return torch.zeros((a, 5)) + + +example_args = (torch.tensor(4),) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsSizeExample() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..22f791a3e80474257c27d927bad56cf4c2fbce78 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check API. + """ + + def forward(self, x, y): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + + if a < 6: + return y.sin() + return y.cos() + + +example_args = (torch.tensor(4), torch.randn(5, 5)) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsValueExample() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..7d24cc681a6b62adf40bfd9a2e5283afb3515131 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import functools + +import torch + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + +class Decorator(torch.nn.Module): + """ + Decorators calls are inlined into the exported function during tracing. + """ + + @test_decorator + def forward(self, x, y): + return x + y + +example_args = (torch.randn(3, 2), torch.randn(3, 2)) +model = Decorator() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..49e688bc0ac1f09567e3b877aaca29a1d02b4121 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened along tracing. + """ + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"python.data-structure"} +model = Dictionary() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..cc822e5553e1ab8bd350a26966c22f1a9a1698cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x + +example_args = (torch.randn(3, 2),) +tags = {"python.assert"} +model = DynamicShapeAssert() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..157e460274ad58ba71c886b35364ddc0cd4d886a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. + """ + + def forward(self, x): + return torch.zeros(x.shape[0] * 2) + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeConstructor() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..21824ef3a0f66eb25f4d8e8c1ba92f53fdd4c275 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeIfGuard(torch.nn.Module): + """ + `if` statement with backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. However, export will fail if the + the dimension is marked as dynamic shape from higher level API. + """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() + +example_args = (torch.randn(3, 2, 2),) +tags = {"torch.dynamic-shape", "python.control-flow"} +model = DynamicShapeIfGuard() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..f8066aed556b9ee588b9744d17ba16c35d8fed6c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import map + +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"torch.dynamic-shape", "torch.map"} +model = DynamicShapeMap() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..decbf036553cb76544a19e531e5aee98d792ae0b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import torch + +from torch._export.db.case import SupportLevel +from torch.export import Dim + +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def forward(self, x): + return x[: round(x.shape[0] / 2)] + +x = torch.randn(3, 2) +dim0_x = Dim("dim0_x") +example_args = (x,) +tags = {"torch.dynamic-shape", "python.builtin"} +support_level = SupportLevel.NOT_SUPPORTED_YET +dynamic_shapes = {"x": {0: dim0_x}} +model = DynamicShapeRound() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..360fe15f6f98d9d735366bfa53371d79e0b00209 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. + """ + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeSlicing() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..46b2637b398c21bf9399d0a3fa2a964354beea3e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out + +example_args = ( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)] +) +example_kwargs = { + "mykw0": torch.randn(4), + "input0": torch.randn(4), + "input1": torch.randn(4), +} +tags = {"python.data-structure"} +model = FnWithKwargs() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..35a140f4ee2e5d6f42c3509984333db896f1c081 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"} +model = ListContains() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..98533cfab5498934a61fbe693bb2497d5dbc9738 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs + +import torch + +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + + def forward(self, args: list[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + x, *y = args + return x + y[0] + +example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],) +tags = {"python.control-flow", "python.data-structure"} +model = ListUnpack() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..122b0ddfc3429fb31415a146e8e1dcfddb2fe031 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,24 @@ +# mypy: allow-untyped-defs +import torch + + +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation raises a warning. Covered in the test_export.py test_detect_leak_strict test. + """ + + def __init__(self) -> None: + super().__init__() + self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() + + +example_args = (torch.randn(3, 2),) +tags = {"python.object-model"} +model = ModelAttrMutation() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..e4076ac14dada40b4d78812666a9ec6b5e67045b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"python.closure"} +model = NestedFunction() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..80d09f68097edbe676077be183711dabe5cbc664 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def forward(self, x): + """ + Null context manager in Python will be traced out. + """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() + +example_args = (torch.randn(3, 2),) +tags = {"python.context-manager"} +model = NullContextManager() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..41e66a7c977a83bf59116864c7f443387277f06e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.randn(2, 3)): + if y is not None: + return x + y + return x + + +example_args = (torch.randn(2, 3),) +tags = {"python.object-model"} +support_level = SupportLevel.SUPPORTED +model = OptionalInput() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..804e73c5a6d58ad5b5be179bf67a5d5bc38c2e2b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +from torch.utils import _pytree as pytree + +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. + """ + + def forward(self, x): + y, _spec = pytree.tree_flatten(x) + return y[0] + 1 + +example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), +model = PytreeFlatten() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..f17092f9afc681b91a982a8a2479ac1dde4f455d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +from enum import Enum + +import torch + +class Animal(Enum): + COW = "moo" + +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. + """ + + def __init__(self) -> None: + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") + +example_args = (torch.randn(3, 2),) +model = SpecializedAttribute() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..aa62b86d16d9b6a1c539976a891f58bd732ae30d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def forward(self, x): + # constant + ret = [i + x for i in range(10)] + return ret + +example_args = (torch.randn(3, 2),) +tags = {"python.control-flow"} +model = StaticForLoop() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..f169380159a45489142ce5ae3523b2e4504c6721 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. + """ + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x + +example_args = (torch.randn(3, 2, 2),) +tags = {"python.control-flow"} +model = StaticIf() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbc263e7ff2240a3cf8618c56f152e744aa40e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + + +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 + +example_args = (torch.randn(3, 2), "attr") +tags = {"python.builtin"} +model = TensorSetattr() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..99ad42a153c512d65aaae1bcac2377ee1e124f25 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class A: + @classmethod + def func(cls, x): + return 1 + x + +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. + """ + + def forward(self, x): + a = A() + return type(a).func(x) + + +example_args = (torch.randn(3, 4),) +tags = {"python.builtin"} +model = TypeReflectionMethod() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a52d80b895b3b2c2d85b878ca4efea511e73ea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/examples/unsupported_operator.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. + """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) + + +example_args = (torch.randn(3, 2),) +tags = {"torch.operator"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = TorchSymMin() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d08f88b0bef092469d0752d3a210d872eb030790 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c08a9760b64ad27a9da63f6cc05104abc6b77f8e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac24e21f214c828f9b822176a50921f30af7b81 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91f719bda471e4778806f41f2848f081cfc47f13 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea593c2859fe5b84abe09f46157175b5ce0271b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bc82dff6880485e70f902b67ee9d3dfe520d976 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89f006ef372e94098ed725c646a67a5a5152b14c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5316a5b4bfdf72088a5d0a2aa1b8365170607d6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83e92ffd3be27d11089abd1c6a10f3cf2cc6eb03 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..538f6d1a186b7940b0e6cf1dec9271284a97851b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe2db29e2f8f194525942f3586eb47acc42cc0c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80bb8732315bfa541349ad3bf55f8b4ee9894020 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af801ac6e89770de50a0f9eeb8bf12c6a76b5560 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86456a08b65a092cd7e1a9a63037dc5ea4a68d62 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a5fc5f0285816b419e887c6aadc6d83b20acccb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/pallas.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e80e4d3a1d09c296e4dd6cff801429592a0688a0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/python_wrapper_mtia.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8fac8adf467eb4bcaa21993c5d5c391e22be056 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62687e9adbc198547e2d537d37e4fd8964a3ea8d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..515ab89d1f2d1187fe2855733ecbbc523fbeae0a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,488 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << '\n'; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred.\n"; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(false); + } + AOTINoGradGuard(const AOTINoGradGuard&) = delete; + AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete; + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete; + AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete; + bool prev_mode{aoti_torch_grad_mode_is_enabled()}; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? "cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0\n"; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* type) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( + AOTInductorModelContainerHandle container_handle, + size_t idx, + size_t* data_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *data_size = container->constant_data_size(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive) { + auto* container = + reinterpret_cast( + container_handle); + auto constants_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { const auto ret = container->extract_constants_map(use_inactive); + for (const auto& pair: ret) { + constants_map->emplace(pair.first, pair.second); + } + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update, /* user_managed = */ true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBufferPairs( + AOTInductorModelContainerHandle container_handle, + const AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast(container_handle); + // Build a local unordered_map inside + std::unordered_map input_map; + input_map.reserve(num_pairs); + for (size_t i = 0; i < num_pairs; ++i) { + input_map.emplace(pairs[i].name, pairs[i].handle); + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + input_map, use_inactive, validate_full_update, /*user_managed=*/true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->free_inactive_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize( + AOTInductorModelContainerHandle container_handle, + uint64_t* ret_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_size = container->constant_blob_size(); }) +} + + +// Load weights from a single blob in weight_blob_ptr +AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob( + AOTInductorModelContainerHandle container_handle, + const uint8_t* weight_blob_ptr){ + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + {container->update_constants_from_blob(weight_blob_ptr); }) + } + + +} // extern "C" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..2496860ca1f7c72eadd86f908384e2f81983af4f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,296 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + MockCutlassHandler, +) +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ...._dynamo.utils import counters +from ... import config +from ...codecache import code_hash, get_path +from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, + WhyNoFuse, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class WhyNoFuseNames(WhyNoFuse): + def __init__(self, name1: str, name2: str) -> None: + self.name1 = name1 + self.name2 = name2 + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode): + assert node1.node, "node1.node should not be None" + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), + [], + node2, # type: ignore[arg-type] + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, BaseSchedulerNode + ): + assert node1.node, "node1.node should not be None" + assert node2.node, "node2.node should not be None" + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node(), # type: ignore[arg-type] + self._unwrap_epilogue_nodes(fnode1), + node2, # type: ignore[arg-type] + ) + + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + # use the original src_code as the key + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + # no EVT kernel, use the original kernel name + kernel_name = f"cutlass_{kernel_hash}" + else: + kernel_name = f"cutlass_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template(template_node), ( + "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc] + assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + kernel, render = ctb.make_kernel_render( # type: ignore[misc] + ctb, epilogue_nodes=epilogue_nodes + ) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + # typically there is a codegen pass which runs after mark_run + # for this kernel we've already generated the C++ code, but we still + # need to let the kernel know about loads/stores that occur in the fused + # kernel for memory planning to properly optimize allocations + ctb.emulate_store_fn() + for node in epilogue_ir_nodes: + with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())): + assert isinstance( + node, ComputedBuffer + ) # Not sure why we need to do this again + node.get_store_function()(CutlassEVTCodegen.get_index_vars(node)) + + with V.set_kernel_handler(kernel): + src_code = render() + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + @staticmethod + def _unwrap_epilogue_nodes( + fused_node: FusedSchedulerNode, + ) -> list[BaseSchedulerNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + assert all(n.node is not None for n in nodes), ( + "All epilogue nodes should have an IRNode" + ) + # pyrefly: ignore [redundant-cast] + return cast( + list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node] + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + existing_epilogue_nodes: list[BaseSchedulerNode], + node_to_fuse: BaseSchedulerNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes. + node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. + + """ + why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name()) + + scheduler_nodes_to_fuse = node_to_fuse.get_nodes() + + assert isinstance(cuda_template_buffer, CUDATemplateBuffer) + + # Checks on constituent nodes + for s_node in scheduler_nodes_to_fuse: + node = s_node.node + + if not isinstance(node, ComputedBuffer): + why(f"{node} is not a ComputedBuffer") + return False + elif not isinstance(node.data, Pointwise): + why(f"{node} is not a Pointwise op") + return False + elif not node.get_computed_buffer_name(): # type: ignore[attr-defined] + why(f"{node} does not have a computed buffer name") + return False + + name = node.get_computed_buffer_name() # type: ignore[attr-defined] + # dtype can differ, and strides can differ as long as they are broadcastable + if node.get_size() != cuda_template_buffer.get_size(): + why( + f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \ +size: {cuda_template_buffer.get_size()}" + ) + return False + + assert len( + existing_epilogue_nodes + ) or cuda_template_buffer.get_name() in OrderedSet( + [rd.name for rd in node_to_fuse.read_writes.reads] + ), "First epilogue node must read from cuda template buffer" + + if node_to_fuse.has_aliasing_or_mutation(): + why(f"{node_to_fuse.get_name()} has aliasing or mutation") + return False + elif node_to_fuse.is_reduction(): + why( + f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT" + ) + return False + elif ( + not config.cuda.cutlass_epilogue_fusion_enabled + or not config.epilogue_fusion + ): + why("cutlass epilogue fusion is not enabled") + return False + elif not cuda_template_buffer.supports_epilogue_fusion: + why("epilogue fusion is only supported for TMA-enabled gemm ops") + return False + + try: + from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + ) + + CutlassEVTCodegen.ir_to_evt_python_code( + cuda_template_buffer.get_name(), + existing_epilogue_nodes + list(node_to_fuse.get_nodes()), + OrderedSet(), + ) + + except NotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \ +likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: # Likely due to unsupported dtype. + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \ +Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca3afbd9ca57a7aed17ffe69d074c667dd2c09f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,55 @@ +import functools +import logging +import shutil +from typing import Optional + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... import config + + +log = logging.getLogger(__name__) + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception: + log.exception("Error getting cuda arch") + return None + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def is_datacenter_blackwell_arch() -> bool: + arch = get_cuda_arch() + if arch is None: + return False + arch_number = int(arch) + return arch_number >= 100 and arch_number < 110 + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception: + log.exception("Error getting cuda version") + return None + + +@functools.cache +def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: + return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..97643ef00a7bd63aa6887c5ee6645f1c788e45fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,687 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal, Optional, TYPE_CHECKING, Union + +from sympy import Expr, symbols + +import torch._inductor.config as config +from torch import dtype as torch_dtype +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder +from torch.utils._sympy.value_ranges import ValueRanges + +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from .cuda_template import ArgInfo + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + ShapeAsConstantBuffer, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import ( + CSEVariable, + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] +ValidLayoutAttrs = Literal["size", "stride"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list) + self.size_args: list[Union[Expr, int]] = [] + # Mapping from arg name to IRNode. + self.named_nodes: dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [ + arg + for arg in itertools.chain.from_iterable(self.layout_args.values()) + if arg.matches(node, attr, dim) + ] + if len(matches) >= 1: + # Verify all matches have the same node, attribute, and dimension + # And if they come from the same node, whichever symbol we use is fine. + # if in runtime the logic changes, this would trigger guard + first_match = matches[0] + if not all( + match.node == first_match.node + and match.attr == first_match.attr + and match.dim == first_match.dim + for match in matches + ): + raise AssertionError("All matching layout args should be identical") + return first_match + return None + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args[symbol].append(arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + x_mdim = _normalize_idx(-2, len(X.get_size())) + x_kdim = _normalize_idx(-1, len(X.get_size())) + w_kdim = _normalize_idx(-2, len(W.get_size())) + w_ndim = _normalize_idx(-1, len(W.get_size())) + y_mdim = _normalize_idx(-2, len(Y.get_size())) + y_ndim = _normalize_idx(-1, len(Y.get_size())) + self.add_layout_arg("M", X, "size", x_mdim) + self.add_layout_arg("K", X, "size", x_kdim) + self.add_layout_arg("K", W, "size", w_kdim) + self.add_layout_arg("N", W, "size", w_ndim) + self.add_layout_arg("M", Y, "size", y_mdim) + self.add_layout_arg("N", Y, "size", y_ndim) + if len(X.get_size()) > 2: + self.add_layout_arg("B", X, "size", 0) + + lda_dim = self.find_ld_idx(X) + ldb_dim = self.find_ld_idx(W) + ldc_dim = self.find_ld_idx(Bias) if Bias else None + ldd_dim = self.find_ld_idx(Y) + self.add_layout_arg("lda", X, "stride", lda_dim) + self.add_layout_arg("ldb", W, "stride", ldb_dim) + if Bias is not None and ldc_dim is not None: + self.add_layout_arg("ldc", Bias, "stride", ldc_dim) + self.add_layout_arg("ldd", Y, "stride", ldd_dim) + + def get_layout_args(self) -> tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + def get_ld(node) -> Union[Expr, int]: + dim = self.find_ld_idx(node) + return node.get_stride()[dim] + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + B = X.get_size()[0] if len(X.get_size()) > 2 else 1 + LDA = get_ld(X) + LDB = get_ld(W) + LDC = get_ld(Bias) if Bias else 0 + LDD = get_ld(Y) + return (M, N, K, B, LDA, LDB, LDC, LDD) + + def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: + return [*self.get_layout_args(), *self.size_args] + + def get_offset_args(self) -> list[Expr]: + return [node.get_layout().offset for node in self.named_nodes.values()] + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + additional_size_args: Additional size arguments for epilogue inputs + """ + # NB: name order matters here, it's used to match up offsets + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + free_symbols: OrderedSet[Expr] = OrderedSet() + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + # NB: named nodes must be populated in the order of names + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + if name not in ( + "X", + "W", + "Bias", + "Y", + ): # we handle these symbolic shapes explicitly + for expr in itertools.chain(node.get_size(), node.get_stride()): + if isinstance(expr, Expr): + for s in expr.free_symbols: + free_symbols.add(s) # type: ignore[arg-type] + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + + self.init_layout_args() + size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] + size_vars.extend(str(s) for s in free_symbols) + self.size_args.extend(free_symbols) + size_args = [f"const int {s}" for s in size_vars] + offset_args = [f"const int {name}_offset" for name in self.named_nodes] + runtime_arg_decls = ",".join( + [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + ) + if runtime_arg_decls: + runtime_arg_decls += ", " + + signature = ( + f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\ + {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + ) + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "CUDATemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # We always originally initialize name with "KERNEL_NAME". So, we + # we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + dynamic_shape_args = self.get_dynamic_shape_args() + offset_args = self.get_offset_args() + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + call_args.extend(offset_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(str(arg)) + arg_types.extend("const int" for _ in dynamic_shape_args) + arg_types.extend("const int" for _ in offset_args) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = ( + call_args[i] + if V.graph.cpp_wrapper + else f"c_void_p({call_args[i]}.data_ptr())" + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append( + workspace + if V.graph.cpp_wrapper + else f"c_void_p({workspace}.data_ptr())" + ) + else: + ws = None + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + + wrapper.generate_kernel_call( + name, + call_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + return f"{arg_name} + {arg_name}_offset" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + def batch_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the batch stride of an arg. + Returns 0 if batch dim is not present. + + This method assumes that batch stride is the largest stride. + """ + + if node is None: + return str(default_value) + + if len(node.get_size()) < 3: + return str(default_value) + + batch_stride = node.get_stride()[0] + if V.graph.sizevars.statically_known_leq(batch_stride, 1): + return str(batch_stride) + + return "{}*{}".format( + self.find_symbol(node, "size", dim=1) or node.get_size()[1], + self.find_symbol(node, "size", dim=2) or node.get_size()[2], + ) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable: + """ + Mock load function for memory planning to optimize allocations properly. + """ + return self.create_cse_var(name, bounds=ValueRanges.unknown()) + + def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None: + """ + Mock store function for memory planning to optimize allocations properly. + """ + self.store_buffer_names.add(name) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]], + tuple[CUDATemplateKernel, functools.partial[str]], + ], + bmreq: CUDABenchmarkRequest, + supports_epilogue_fusion: bool, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.supports_epilogue_fusion = supports_epilogue_fusion + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def kernel_hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + swizzle_str: str = ( + str(self.info_kwargs.get("swizzle")) + if isinstance(self.info_kwargs, dict) + else "None" + ) + return "-".join( + [ + self.category, + self.bmreq.hash_key, + swizzle_str, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """ + Information returned here is logged to the autotune log file when that is enabled. + + In general, we should avoid calling this function as it is expensive to compute, + and can add up very fast. + """ + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + "swizzle": str(self.info_kwargs["swizzle"]), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + self.bmreq.update_workspace_size() + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + supports_epilogue_fusion=self.supports_epilogue_fusion, + template=self.template, + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..79dfa9c6c391fe10ce2c4a657aea83b1639f4f5d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,394 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import itertools +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import override +from unittest.mock import patch + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.utils import clear_on_fresh_cache, Placeholder +from torch._logging import getArtifactLogger + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from ...scheduler import BaseSchedulerNode # noqa: TC004 +else: + BaseSchedulerNode = Any + +GemmOperation = Any + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +@clear_on_fresh_cache +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + # dict of cache key to (code, size_args) + code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {} + cache_clear = staticmethod(code_cache.clear) + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + Baseclass for CUDA C++ Templates, derived from KernelTemplate. + Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies + the order of the input nodes. + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + @classmethod + @functools.lru_cache(None) + # pyrefly: ignore [bad-override] + def _template_from_string(cls, source: str) -> Any: + return KernelTemplate._template_from_string(source) + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return False + + def make_key(self, name: str, input_key: str, layout_repr: str) -> str: + """ + Make a key for the code cache. The idea of the method is to cache + everything that matters but doesn't include runtime param values, i.e., + self.get_runtime_arg_values(). + + Args: + kwargs: Additional keyword arguments. Including op (GemmOperation). + """ + return hashlib.sha256( + str( + ( + input_key, + self.input_reorder, + # output layout, same as self.output_node.get_layout() + layout_repr, + self.get_runtime_arg_info(), + name, + ) + ).encode("utf-8") + ).hexdigest() + + def generate_code_and_args( + self, name: str, input_key: str, layout_repr: str, **kwargs + ) -> tuple[str, tuple[int, ...]]: + """ + Generate code and args with caching. We cache the code even if runtime + args are different. + """ + key: Optional[str] = None + if config.cuda.enable_caching_codegen: + key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) + + if key is not None and key in self.code_cache: + code, size_args, offset_args = self.code_cache[key] + extra_args = tuple( + list(size_args) + + list(offset_args) + + list(self.get_runtime_arg_values(**kwargs)) + ) + return code, extra_args + + kernel_name = str(Placeholder.KERNEL_NAME) + kernel = CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + autotuning_log.debug("Generated Code:\n%s", code) + autotuning_log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) + size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args()) + + if key is not None: + self.code_cache[key] = code, size_args, offset_args + + # extra args has runtime params, which shouldn't be cached + extra_args = tuple( + list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs) + ) + + return code, extra_args + + def generate( # type: ignore[override] + self, + name: str, + description: str, + input_key: str, + layout_repr: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. + This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel + in a standalone manner to enable Autotuning. + + Args: + description: op name followed by swizzle. + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + code, extra_args = self.generate_code_and_args( + name=name, + input_key=input_key, + layout_repr=layout_repr, + **kwargs, + ) + + # not caching since kernel name is needed below + kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] + kernel_name = f"cutlass_{kernel_hash}" + code = code.replace(self.name, kernel_name) + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + extra_args=extra_args, + source_code=code, + ) + + # kwargs has "op" argument in case of CUTLASSGemmTemplate + op = kwargs["op"] + if not op: + supports_epilogue_fusion = False + else: + # epilogue fusion is only supported for TMA kernels + supports_epilogue_fusion = self.supports_epilogue_fusion(op) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + ) -> tuple[CUDATemplateKernel, functools.partial[str]]: + assert supports_epilogue_fusion or not epilogue_nodes, ( + "epilogue fusion is not supported for this kernel" + ) + kernel = CUDATemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_name, + "cutlass_gemm", + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + supports_epilogue_fusion, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in ("1", "1L"): + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + torch.float8_e4m3fn: "cutlass::float_e4m3_t", + } + + _DTYPE_TO_CUTLASS_SPARSE_META = { + torch.int32: "uint32_t", + torch.int16: "uint16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return ( + f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})" + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("swizzle", "const uint8_t")] + + @override + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..66db98867b4131631540d262b4e7eb4c932cc02a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import inspect +import json +import logging +import os +import time +from typing import Any, Optional + +import torch._inductor.config as config +from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda import cutlass_utils, serialization +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +CONFIG_PREFIX: str = "configs" + + +def get_config_request_key( + arch: str, + cuda_version: str, + instantiation_level: str, +) -> str: + """ + Return a key for the full ops, based on cutlass key, arch, cuda version, instantiation level, and serialization.py file hash. + """ + + # Get hash of serialization.py and cutlass_utils.py files using their module file paths + def get_file_hash(file_module): + file_path = inspect.getfile(file_module) + with open(file_path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + serialization_hash = get_file_hash(serialization) + cutlass_utils_hash = get_file_hash(cutlass_utils) + + hash_target = "-".join( + [ + cutlass_key().hex(), + arch, + cuda_version, + instantiation_level, + serialization_hash, + cutlass_utils_hash, + ] + ) + return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] + + +def _generate_config_filename(request_key: str) -> str: + """ + Generate a filename for the full ops. + """ + return f"{CONFIG_PREFIX}_{request_key}.json" + + +@clear_on_fresh_cache +@functools.cache +def maybe_fetch_ops() -> Optional[list[Any]]: + """ + Fetch ops from databases. + """ + if config.force_disable_caches: + return None + + # setup + arch: str = get_cuda_arch() + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version: str = ".".join(get_cuda_version().split(".")[:2]) + instantiation_level: str = config.cuda.cutlass_instantiation_level + + # filename and filepath + request_key: str = get_config_request_key(arch, version, instantiation_level) + filename: str = _generate_config_filename(request_key) + filepath: str = os.path.join(cache_dir(), filename) + + # try fetch + serialized_ops: Optional[list[str]] = None + start_time = time.time() + if os.path.isfile(filepath): + # locally + try: + with open(filepath) as f: + serialized_ops = json.load(f) + + assert isinstance(serialized_ops, list), ( + f"Expected serialized ops is a list, got {type(serialized_ops)}" + ) + except Exception: + log.warning( + "Failed to load CUTLASS config %s from local cache", + filename, + exc_info=True, + ) + serialized_ops = None + elif config.is_fbcode(): + from torch._inductor.fb.cutlass_remote_cache import ( + maybe_fetch_cutlass_configs_from_remote, + ) + + # from remote + serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath) + + if serialized_ops is None: + return None + + # deserialize + serializer = get_cutlass_operation_serializer() + full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr] + log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time) + return full_ops diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b7d2afe6c39e27e81c3b78d2c411f3cdf7193e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -0,0 +1,326 @@ +import itertools +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from os import linesep +from typing import Any, Optional + +import sympy + +import torch +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, Pointwise +from torch._inductor.ops_handler import DefaultHandler, WrapperHandler +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet +from torch._inductor.virtualized import OpsValue + +from ...virtualized import V + + +_ACCUMULATOR_ARG_NAME = "accum" + + +def scaled_mm_evt( + scale_A_name: str, scale_B_name: str, bias_name: Optional[str], output_name: str +) -> tuple[list[str], dict[str, Any], str]: + evt_read_names = [scale_A_name, scale_B_name] + var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]} + var_name_to_buffer_name["D"] = output_name + var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name + expr = f"accum * {scale_A_name} * {scale_B_name}{linesep}" + if bias_name: + expr = f"({expr}) + {bias_name}" + evt_read_names.append(bias_name) + var_name_to_buffer_name[bias_name] = bias_name + + evt_py_code = f"def fn(accum, {','.join(evt_read_names)}):{linesep}\ + D = {expr}{linesep}\ + return D{linesep}" + + return evt_read_names, var_name_to_buffer_name, evt_py_code + + +class CutlassEVTOpsMixIn: + @staticmethod + def _infix_bin_op(op: str, a: str, b: str) -> str: + return f"{a} {op} {b}" + + @staticmethod + def _prefix_bin_op(op: str, a: str, b: str) -> str: + return f"{op}({a}, {b})" + + @staticmethod + def _prefix_un_op(op: str, a: str) -> str: + return f"{op}({a})" + + @staticmethod + def to_dtype( + x: str, + dtype: Any, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> str: + return x + + @staticmethod + def constant(value: Any, dtype: Any) -> str: + raise NotImplementedError + + @staticmethod + def mul(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1) + + @staticmethod + def truediv(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1) + + @staticmethod + def ge(x0: str, x1: str) -> str: + raise NotImplementedError + + @staticmethod + def add(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1) + + @staticmethod + def relu(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("relu", x0) + + @staticmethod + def sigmoid(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0) + + @staticmethod + def sub(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1) + + @staticmethod + def tanh(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0) + + @staticmethod + def exp(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("exp", x0) + + +class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler): + """Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning""" + + +class _AssignmentFormatter(DefaultHandler): + def __init__(self, parent_handler: "CutlassEVTCodegen"): + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # Handle op dispatch here + if hasattr(self.parent_handler, name): + fn = getattr(self.parent_handler, name) + line = fn(*args, **kwargs) + if name in ("load", "store"): + return OpsValue(line) + else: + var = self.parent_handler._tmp_var() + line = DelayReplaceLine( + var, + lambda: "D" + if var == self.parent_handler.last_stored_var_name + else var, + f"{var} = {line}", + ) + self.parent_handler.body.writeline(line) + return OpsValue(var) + else: + raise NotImplementedError(name) + + +class CutlassEVTCodegen(CutlassEVTOpsMixIn): + """ + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTCodegen.ir_to_evt_python_code(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + """ + + def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTCodegen.ir_to_evt_python_code static method. + + Args: + accumulator_node_name: The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + epilogue_nodes: The list of scheduler nodes to be fused into the epilogue + """ + self.accumulator_node_name: str = accumulator_node_name # + self.body: IndentedBuffer = IndentedBuffer(1) # The body buffer for codegen + self.var_counter: Iterator[int] = itertools.count() + self.store_name_to_value: dict[str, OpsValue] = ( + dict() + ) # Aliases for subexpression functors + self.reads: OrderedSet[str] = OrderedSet([]) + # Used for creating example tensors + self.var_name_to_buffer_name: dict[str, str] = { + _ACCUMULATOR_ARG_NAME: accumulator_node_name + } + self.removed_buffers: OrderedSet[str] = removed_buffers + self.cur_node: Optional[ComputedBuffer] = None + self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants: + self.name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + self.is_D_assigned = False + self.D_var_name = None + + if accumulator_node_name not in removed_buffers: + # cannot return accumulator directly, so alias it + var = self._tmp_var() + self.body.writeline(f"{var} = {_ACCUMULATOR_ARG_NAME}") + self.store(accumulator_node_name, value=OpsValue(var)) + + @staticmethod + def ir_to_evt_python_code( + cuda_template_node_name: str, + epilogue_nodes: list[BaseSchedulerNode], + removed_buffers: OrderedSet[str], + ) -> tuple[list[str], list[str], dict[str, Any], str]: + codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers) + handler = _AssignmentFormatter(codegen) + + with virtualized.V.set_ops_handler(handler): + for s_node in epilogue_nodes: + node = s_node.node + assert isinstance(node, ComputedBuffer) + with codegen.set_cur_node(node): + index_vars = CutlassEVTCodegen.get_index_vars(node) + node.get_store_function()(index_vars) + + codegen.finalize() + + return ( + codegen.get_reads(), + codegen.get_writes(), + codegen.get_renames(), + codegen.get_value(), + ) + + def get_value(self) -> str: + return linesep.join( + [ + self._render_input_signature(), + self.body.getvalue(), + self._render_return_statement(), + ] + ) + + def finalize(self) -> None: + # Rename the last store to D + # no other code references this store + # to workaround https://github.com/NVIDIA/cutlass/issues/2288 + # Note: the delayed line will automatically rewrite the last assignment to + # be to D + buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name] + self.var_name_to_buffer_name.pop(self.last_stored_var_name) + self.var_name_to_buffer_name["D"] = buffer_name + self.store_name_to_value[buffer_name] = OpsValue("D") + + @contextmanager + def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]: + prev_node = self.cur_node + try: + self.cur_node = node + yield + finally: + self.cur_node = prev_node + + def get_renames(self) -> dict[str, str]: + return dict(self.var_name_to_buffer_name) + + def get_reads(self) -> list[str]: + return list(self.reads.difference(self.store_name_to_value.keys())) + + def get_writes(self) -> list[str]: + return list(self.store_name_to_value.keys()) + + def load(self, name: str, index: Any) -> str: + self._check_indexing(name, index) + if name in self.store_name_to_value: + return self.store_name_to_value[name].value + elif name == self.accumulator_node_name: + return _ACCUMULATOR_ARG_NAME + else: + self.reads.add(name) + self.var_name_to_buffer_name[name] = name + return name + + def store( + self, name: Any, index: Any = None, value: Any = None, mode: Any = None + ) -> None: + if name not in self.removed_buffers: + if index: + self._check_indexing(name, index) + assert value.value != _ACCUMULATOR_ARG_NAME, ( + "Cannot store accumulator arg name" + ) + self.var_name_to_buffer_name[value.value] = name + self.store_name_to_value[name] = value + self.last_stored_var_name = value.value + return None + + def _get_cur_node(self) -> ComputedBuffer: + assert self.cur_node + return self.cur_node + + @staticmethod + def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]: + data = node.data + # TODO mlazos: relax this, cutlass supports reductions and other ops + assert isinstance(data, Pointwise) + return data._index(data.ranges) + + def _get_current_index_vars(self) -> Sequence[sympy.Expr]: + return self.get_index_vars(self._get_cur_node()) + + def _check_indexing(self, name: str, index: sympy.Expr) -> None: + # We only support indexing that matches the layout today because + # CUTLASS doesn't support arbitrary indexing + buffer_name = ( + self.accumulator_node_name if name == _ACCUMULATOR_ARG_NAME else name + ) + buffer = self.name_to_buffer[buffer_name] + index_strides = V.graph.sizevars.stride_vars( + index, self._get_current_index_vars() + ) + stride = buffer.get_layout().stride + if not self._stride_compatible(stride, index_strides): + raise NotImplementedError( + f"Unsupported indexing for {name} with index {index}, index strides {index_strides}, and layout stride {stride}" + ) + + def _stride_compatible( + self, left: Iterable[sympy.Expr], right: Iterable[sympy.Expr] + ) -> bool: + return all( + sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0) + for l, r in (zip(left, right)) + ) + + def _render_input_signature(self) -> str: + arguments = ", ".join( + [_ACCUMULATOR_ARG_NAME] + + [name for name in self.reads if name != self.accumulator_node_name] + ) + return f"def fn({arguments}):" + + def _render_return_statement(self) -> str: + return_vars = OrderedSet( + op_v.value for op_v in self.store_name_to_value.values() + ) + assert "D" in return_vars + return f"return {', '.join(return_vars)}" + + def _tmp_var(self) -> str: + return f"tmp_{next(self.var_counter)}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa46e8766cd5819b41af4c5269945119722d2251 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,493 @@ +# mypy: allow-untyped-defs +import atexit +import functools +import logging +import os +import shutil +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional +from typing_extensions import TypeIs + +import sympy + +import torch +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.utils import clear_on_fresh_cache +from torch.utils._ordered_set import OrderedSet + +from ... import config +from ...ir import Layout +from ...runtime.runtime_utils import cache_dir +from ...virtualized import V +from ..cpp_utils import DTYPE_TO_CPP +from .cuda_env import get_cuda_arch, get_cuda_version + + +log = logging.getLogger(__name__) + +CUTLASS_OPERATION_KIND: str = "gemm" +ACCUMULATOR_DTYPES: OrderedSet[torch.dtype] = OrderedSet([torch.float, torch.int32]) +XW_DTYPES: OrderedSet[torch.dtype] = OrderedSet( + [torch.half, torch.bfloat16, torch.float8_e4m3fn, torch.int8] +) + + +@atexit.register +def move_cutlass_compiled_cache() -> None: + """Move CUTLASS compiled cache file to the cache directory if it exists.""" + if not try_import_cutlass.cache_info().currsize > 0: + return + + import cutlass_cppgen # type: ignore[import-not-found] + + # Check if the CACHE_FILE attribute exists in cutlass_cppgen and if the file exists + if not hasattr(cutlass_cppgen, "CACHE_FILE") or not os.path.exists( + cutlass_cppgen.CACHE_FILE + ): + return + + try: + filename = os.path.basename(cutlass_cppgen.CACHE_FILE) + shutil.move(cutlass_cppgen.CACHE_FILE, os.path.join(cache_dir(), filename)) + log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) + except OSError: + log.warning("Failed to move CUTLASS compiled cache file", exc_info=True) + + +def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +@functools.cache +def try_import_cutlass() -> bool: + """ + We want to support three ways of passing in CUTLASS: + 1. fbcode, handled by the internal build system. + 2. User specifies cutlass_dir. The default is ../third_party/cutlass/, + which is the directory when developers build from source. + """ + if config.is_fbcode(): + try: + import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401 + import cutlass_library # type: ignore[import-not-found] + except ImportError as e: + log.warning( # noqa: G200 + "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", + str(e), + ) + return False + + return True + + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. + + # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass, + # but will be moved to python/cutlass_library in the future (later 2025) + def path_join(path0, path1): + return os.path.abspath(os.path.join(path0, path1)) + + # contains both cutlass and cutlass_library + # we need cutlass for eVT + cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + torch_root = os.path.abspath(os.path.dirname(torch.__file__)) + mock_src_path = os.path.join( + torch_root, + "_inductor", + "codegen", + "cuda", + "cutlass_lib_extensions", + "cutlass_mock_imports", + ) + + cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library") + cutlass_cppgen_src_path = path_join(cutlass_python_path, "cutlass_cppgen") + pycute_src_path = path_join(cutlass_python_path, "pycute") + + tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass")) + + dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library") + dst_link_cutlass_cppgen = path_join(tmp_cutlass_full_path, "cutlass_cppgen") + dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute") + + # mock modules to import cutlass + mock_modules = ["cuda", "scipy", "pydot"] + + if os.path.isdir(cutlass_python_path): + if tmp_cutlass_full_path not in sys.path: + + def link_and_append(dst_link, src_path, parent_dir): + if os.path.lexists(dst_link): + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + ) + assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + src_path, + ), f"Symlink at {dst_link} does not point to {src_path}" + else: + os.makedirs(parent_dir, exist_ok=True) + os.symlink(src_path, dst_link) + + if parent_dir not in sys.path: + sys.path.append(parent_dir) + + link_and_append( + dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path + ) + link_and_append( + dst_link_cutlass_cppgen, cutlass_cppgen_src_path, tmp_cutlass_full_path + ) + link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path) + + for module in mock_modules: + link_and_append( + path_join(tmp_cutlass_full_path, module), # dst_link + path_join(mock_src_path, module), # src_path + tmp_cutlass_full_path, # parent + ) + + try: + import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401, F811 + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + import pycute # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError as e: + log.debug( # noqa: G200 + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_python_path, + ) + return False + + +@functools.lru_cache(8) +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 100: + log.warning( + "Detected CUDA architecture >= 100: %s. We will generate operations with " + "GenerateSM100 (if available) and GenerateSM90. Please file an " + "issue for any problems and feedback. ", + arch, + ) + + if int(arch) >= 100: + return "100" + elif int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. + """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + instantiation_level: Optional[str] = None + operations: Optional[str] = None + + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + exclude_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@clear_on_fresh_cache +@functools.cache +def _gen_ops_cached(arch, version) -> dict[Any, Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return {} + arch = _normalize_cuda_arch(arch) + instantiation_level: str = config.cuda.cutlass_instantiation_level + args = CUTLASSArgs( + architectures=arch, + cuda_version=version, + instantiation_level=instantiation_level, + operations=CUTLASS_OPERATION_KIND, + ) + manifest = cutlass_manifest.Manifest(args) + + start_time = time.time() + if arch == "100": + if hasattr(cutlass_generator, "GenerateSM100"): + cutlass_generator.GenerateSM100(manifest, args.cuda_version) + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + + log.info( + "CUTLASS library generated a dict of %d operation kinds in %.2f seconds", + len(manifest.operations), + time.time() - start_time, + ) + return manifest.operations + + +def gen_ops() -> dict[Any, Any]: + """ + Generates all supported CUTLASS operations. + """ + with dynamo_timed("cutlass_utils.gen_ops"): + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", +} + + +@functools.lru_cache(32) +def torch_dtype_to_cutlass_type( + torch_dtype: torch.dtype, +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library # type: ignore[import] + + if torch_dtype == torch.float: + return cutlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +@functools.lru_cache(32) +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return cutlass_dtype == cutlass_library.library.DataType.s8 + elif torch_dtype == torch.uint8: + return cutlass_dtype == cutlass_library.library.DataType.u8 + elif torch_dtype == torch.int32: + return cutlass_dtype == cutlass_library.library.DataType.s32 + elif torch_dtype == torch.float8_e4m3fn: + return cutlass_dtype == cutlass_library.library.DataType.e4m3 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: list[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + assert OrderedSet(input_torch_dtypes) <= XW_DTYPES, ( + f"{input_torch_dtypes=} is not supported" + ) + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + else: + size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size() + size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size() + if size0 > size1: + dtype0, dtype1 = input_torch_dtypes + else: + dtype1, dtype0 = input_torch_dtypes + if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [ + torch.int8, + torch.uint8, + ]: + torch_dtype = dtype0 + + if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn): + accumulator_dtype = torch.float + elif torch_dtype == torch.int8: + accumulator_dtype = torch.int32 + else: + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") + + assert accumulator_dtype in ACCUMULATOR_DTYPES, ( + f"{accumulator_dtype=} is not supported" + ) + return accumulator_dtype + + +@functools.lru_cache(32) +def get_alignments(torch_dtype: torch.dtype) -> list[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn): + return [16, 8, 4, 2] + elif torch_dtype == torch.int32: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. + """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number: object) -> TypeIs[int | sympy.Integer]: + return isinstance(number, (int | sympy.Integer)) + + def a_factor_of(x, alignment): + if is_static_int(x) and is_static_int(alignment): + return x % alignment == 0 + rem = sympy.Mod(x, alignment) + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + alignments = get_alignments(dtype) + for alignment in alignments: + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( + offset, alignment + ): + continue + if all( + (dim == contiguous_dim) + or a_factor_of(inductor_layout.stride[dim], alignment) + for dim in range(len(size)) + ): + return alignment + return 1 + + +class CUDACompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self): + self.sources = [] + self._compile_patch = None + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile + + def my_compile( + source_code, dst_file_ext, extra_args: Optional[list[str]] = None + ): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + # pyrefly: ignore [bad-assignment] + self._compile_patch = mock.patch( + "torch._inductor.codecache.CUDACodeCache.compile", my_compile + ) + self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + return self + + def __exit__(self, *args, **kwargs): + self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. + from torch._inductor.codecache import cuda_compile_command + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] + compile_command = cuda_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..147515e0decfe8f14853e18193fa4ca45501cac8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + """ + CUDA-specific codegen functions, see DeviceOpOverrides for details + """ + + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.cuda._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTICudaGuard" + + def cpp_stream_guard(self) -> str: + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::cuda::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self) -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + CUresult code_get_error = cuGetErrorString(code, &msg); \\ + if (code_get_error != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline CUfunction loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) { + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoadData(&mod, start)); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def tma_descriptor_helpers(self) -> str: + """ + CUDA helper functions for initializing TMA Descriptors on host side + """ + if torch.version.hip is not None: + raise RuntimeError("Host-side TMA descriptors not supported on HIP.") + + # helper functions for initializing 1D and 2D TMA descriptors in C++. borrowed from the Triton code here: + # Old APIs (fill(1|2)DTMADescriptor): + # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283 + # New APIs (fillTMADescriptor): + # https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L283 + return """ + #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + [[maybe_unused]] static void init1DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim, + uint32_t blockDim, + uint32_t elementSize) { + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t tensorDims[1] = {blockDim}; + uint32_t elementStrides[1] = {1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + if (elementSize * blockDim < 32) { + throw std::runtime_error("block size too small"); + } + + int rank = 1; + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void init2DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim1, + uint64_t dim0, + uint32_t blockDim1, + uint32_t blockDim0, + uint32_t elementSize) { + uint64_t dims[2] = {dim0, dim1}; + uint32_t tensorDims[2] = {blockDim0, blockDim1}; + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + int rank = 2; + + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void initTMADescriptor( + CUtensorMap* m, + void* globalAddress, + int elemSize, + int rank, + uint32_t* blockSize, + uint64_t* shape, + uint64_t* stride + ) { + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + uint32_t blockSizeInt[5]; + uint64_t shapeInt[5]; + uint64_t stridesLL[5]; + + // Reorder blockSize (reverse the order) + for (int i = 0; i < rank; ++i) { + blockSizeInt[rank - i - 1] = blockSize[i]; + } + + // Reorder shape (reverse the order) + for (int i = 0; i < rank; ++i) { + shapeInt[rank - i - 1] = shape[i]; + } + + // Reorder and calculate strides + for (int i = 0; i + 1 < rank; ++i) { + stridesLL[rank - i - 2] = elemSize * stride[i]; + } + stridesLL[rank - 1] = + shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]); + + CUtensorMapDataType type; + // In Triton this is computed ahead of time; but for simplicity + // in the PyTorch version we copied this code from the old + // TMA API handling (i.e. init2DTMADescriptor) + switch (elemSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elemSize must be 1, 2, or 4"); + } + + // Calculate the size of the most contiguous dimension in bytes + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elemSize * blockSizeInt[0]; + if (rank == 1) { + // rank 1 should not be swizzled + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, + shapeInt, stridesLL, blockSizeInt, elementStrides, + CU_TENSOR_MAP_INTERLEAVE_NONE, (CUtensorMapSwizzle)swizzle, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + struct StableTMADescriptor { + CUtensorMap m; + uint32_t block_shape[5]; + uint64_t global_shape[5]; + uint64_t strides[5]; + }; + #endif + """ + + def cpp_stream_type(self) -> str: + return "cudaStream_t" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self) -> str: + return "CUfunction" + + def cpp_device_ptr(self) -> str: + return "CUdeviceptr" + + def cpp_scratch( + self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None + ) -> Optional[tuple[list[str], str]]: + prefix = f"{prefix}_" if prefix else "" + var_name = f"{prefix}scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b7188bd9e621eb4a2bed773d7a5a116bca9b3e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,1966 @@ +# mypy: allow-untyped-defs +import copy +import enum +import functools +import logging +import re +import time +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._inductor.autotune_process import TensorMeta +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.select_algorithm import create_inputs_key +from torch._inductor.utils import clear_on_fresh_cache + +from ... import ir +from ...config import cuda as inductor_cuda_config +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + FixedLayout, + IRNode, + Layout, + ReinterpretView, +) +from ...utils import is_dynamic, Placeholder +from ...virtualized import V +from ..common import IndentedBuffer +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +from .cutlass_utils import ( + ACCUMULATOR_DTYPES, + dtype_match, + torch_dtype_to_cutlass_type, + XW_DTYPES, +) + + +GemmOperation = Any +EVTArgRenames = Any + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{epilogue_visitor_tree}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} + +// configuration name: {{op_conf_name}} +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}}, + hw_info + }; + arguments.scheduler.max_swizzle_size = swizzle; +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + +# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below. +GEMM_TEMPLATE_CUTLASS_2X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below. +GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + +GEMM_ARGS_SPARSE_CUTLASS_2X = r""" + using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA, + {{instance_type}}::LayoutA>; + using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB, + {{instance_type}}::LayoutB>; + using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC, + {{instance_type}}::LayoutC>; + using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE, + {{instance_type}}::LayoutE>; + // Note that "X" and "W" names may be misleading here. Namely, for + // sparse GEMM, the first argument is always sparse, while typically + // weight matrix, implied by name "W" will be sparse in + // applications. Thus, just remember that here: "X" refers to first + // argument, that is sparse, and "W" to second, that is dense. + TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}}); + TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}}); + TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}}); + TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}}, + TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} })); + // Initialize GemmSparse arguments. + arguments = { + { + static_cast(M), + static_cast(N), + static_cast(2 * K), + }, // GemmCoord problem_size + X_ref, // TensorRef ref_A + W_ref, // TensorRef ref_B + Y_ref, // TensorRef ref_C + Y_ref, // TensorRef ref_D + Meta_ref, // TensorRef ref_E + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue, + }; +""" + +# Additional includes which are necessary if the standalone test / debug runner is generated as well +GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" +#ifdef GENERATE_STANDALONE_RUNNER +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include +#endif +""" + +# Jinja template for the standalone runner that may be generated as part of the code. +GEMM_STANDALONE_RUNNER_TEMPLATE = r""" +#ifdef GENERATE_STANDALONE_RUNNER +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed, float max=1.0, float min=-1.0) { + if (block.size()<=0) return false; + Element scope_max(static_cast(max)), scope_min(static_cast(min)); + cutlass::reference::device::BlockFillRandomUniform( + (Element*)block.get(), block.size(), seed, scope_max, scope_min); + + return true; +} + +{% if Meta is defined and Meta is not none %} +template +bool initialize_block_meta( + cutlass::DeviceAllocation& block, + uint64_t seed) { + if (block.size()<=0) return false; + cutlass::reference::device::BlockFillRandomSparseMeta( + (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits); + return true; +} +{% endif %} + +extern "C" int run_standalone(uint64_t seed, int repetitions) { + std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; + size_t workspace_size = 0; + size_t* workspace_size_ptr = &workspace_size; + + int M = {{kernel.get_layout_args()[0]}}; + int N = {{kernel.get_layout_args()[1]}}; + int K = {{kernel.get_layout_args()[2]}}; + int B = {{kernel.get_layout_args()[3]}}; + int lda = {{kernel.get_layout_args()[4]}}; + int ldb = {{kernel.get_layout_args()[5]}}; + int ldc = {{kernel.get_layout_args()[6]}}; + int ldd = {{kernel.get_layout_args()[7]}}; + uint8_t swizzle = {{kernel.runtime_arg_values[0]}}; + + using ElementA = {{kernel.cutlass_dtype(X)}}; + using ElementB = {{kernel.cutlass_dtype(W)}}; + using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void + using ElementD = {{kernel.cutlass_dtype(Y)}}; + {% if Meta is defined and Meta is not none %} + using ElementE = {{kernel.cutlass_dtype(Meta)}}; + {% endif %} + + cutlass::DeviceAllocation X_data({{kernel.max_valid_index(X)+1}}); + initialize_block(X_data, seed++); + cutlass::DeviceAllocation W_data({{kernel.max_valid_index(W)+1}}); + initialize_block(W_data, seed++); + cutlass::DeviceAllocation Bias_data({{kernel.max_valid_index(Bias)+1}}); + initialize_block(Bias_data, seed++); + cutlass::DeviceAllocation Y_data({{kernel.max_valid_index(Y)+1}}); + {% if Meta is defined and Meta is not none %} + cutlass::DeviceAllocation Meta_data({{kernel.max_valid_index(Meta)+1}}); + initialize_block_meta(Meta_data, seed++); + {% endif %} + + cutlass::DeviceAllocation workspace_data; + // Call once with workspace_size_ptr set to get workspace size + + std::cout << "Calling once to get workspace size" << std::endl; + {{test_call_statement}}; + // Allocate workspace if necessary + if (workspace_size > 0) { + workspace_data.reset(workspace_size); + std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; + } + std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl; + workspace_size_ptr = nullptr; + for (int i=0; i None: + """ + Args: + input_nodes (List[Buffer]): List of input nodes of the GEMM kernel. + layout (Layout): Layout type of the resulting output node. + alpha (float): The scaling factor for the product of the inputs in the GEMM operation. + beta (float): The scaling factor applied to the output matrix. + input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided, + no reordering is performed. Defaults to None. + """ + super().__init__( + str(Placeholder.KERNEL_NAME), input_nodes, layout, input_reorder + ) + self.alpha = alpha + self.beta = beta + self.use_fast_accum = use_fast_accum + assert 2 <= len(input_nodes) <= 5 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + self.cache_key: str = create_inputs_key(self.input_nodes) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @staticmethod + @abstractmethod + def _has_tma_epilogue(self) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + raise NotImplementedError + + @abstractmethod + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. + + """ + + ops = self.gen_ops() + + # pre-computation + layout_repr: str = str(layout) + input_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.input_nodes) + ) + output_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.output_node) + ) + + with dynamo_timed("CUTLASSGemmTemplate.maybe_append_choice"): + for name, op in ops: + for ( + swizzle + ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, + op=op, + name=name, + description=description, + input_key=self.cache_key, + layout_repr=layout_repr, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + swizzle=swizzle, + ) + + if len(ops) == 0: + log.info( + "No suitable Cutlass GEMM configs found, fallbacks used " + "( len(ops)=%d, output_layout=%s, input_layouts=%s, input_strides=%s )", + len(ops), + layout, + [node.get_layout() for node in input_nodes], + [node.get_stride() for node in input_nodes], + ) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/device/gemm_sparse.h" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + if inductor_cuda_config.generate_test_runner and not is_dynamic( + *self.input_nodes, self.output_node + ): + res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return cutlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + @functools.lru_cache(32) + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> None: # type: ignore[name-defined] # noqa: F821 + """ + Helper method: Sets the layout of a given tensor description to match the given torch layout + """ + if CUTLASSGemmTemplate.layout_match(torch_layout, tensor_desc.layout): + return + tensor_desc.layout = CUTLASSGemmTemplate.cutlass_layout(torch_layout) + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + cuda_arch = cutlass_utils.get_cuda_arch() + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def should_swap_XW( + bias: IRNode, + ) -> bool: + """ + Helper method to determine whether we should do an explicit transpose by switching the order of the + matmul operands. This might be necessary when we can't otherwise arrive at the right memory + layout for the given Bias operand. + + Note: This method is a workaround for CUDA Errors that seemingly non-deterministically + occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. + it might make sense to check on newer Cutlass releases whether it makes sense to keep + returning True in certain cases or whether it becomes unnecessary. + """ + # If bias is row major, swap all M and N dimensions + if ( + bias is not None + and len(bias.get_stride()) >= 2 + and bias.get_stride()[-1] in (0, 1) + ): + log.debug("GEMM Layout swapped X and W -> explicit transpose") + return True + return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Swap operands X and W (aka operans A and B) of the GEMM operation. This + requires transposing the operands, which is done by swapping the strides. + Note that we don't change the apparent external layout, just the operand layout. + this is intentional. + """ + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def fix_op_layout( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. + a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout) + new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout) + return new_op + + def _dtype_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """ + Checking dtypes of A, B, acc, D here. + + Empirically speaking, CUTLASS2x ops have same dtype for C and D. + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.D.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return False + + return True + + @classmethod + def global_filter_ops( + cls, + ops: list["cutlass_library.gemm_op.GemmOperation"], # type: ignore[name-defined] # noqa: F821 + ) -> list["cutlass_library.gemm_op.GemmOperation"]: # type: ignore[name-defined] # noqa: F821 + """ + Filter ops without using information about the torch op, input nodes and output node. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + # Skip simt kernels + ops = [ + op + for op in ops + if op.tile_description.math_instruction.opcode_class + != cutlass_lib.OpcodeClass.Simt + ] + + # only keep the set of row x column ops + # for other layout, we modify in place in filter_op, after deepcopy + ops = [ + op + for op in ops + if op.A.layout.name == "RowMajor" and op.B.layout.name == "ColumnMajor" + ] + + # filter by supported accumulator types + ops = [ + op + for op in ops + if any( + dtype_match(torch_dtype, op.accumulator_type()) + for torch_dtype in ACCUMULATOR_DTYPES + ) + ] + + # check if dtypes of A and B are supported + ops = [ + op + for op in ops + if any(dtype_match(torch_dtype, op.A.element) for torch_dtype in XW_DTYPES) + and any(dtype_match(torch_dtype, op.B.element) for torch_dtype in XW_DTYPES) + ] + + return ops + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. + + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + if op.gemm_kind not in self._get_supported_ops(): + return None + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. + if not self._dtype_match(op): + return None + + # Filter ops by alignment. + if not self._alignment_match(op): + log.debug( + "Skipping due to alignment mismatch. op: %s", op.configuration_name() + ) + return None + + # only use stream k for static shape + if op.tile_scheduler.name == "StreamK": + static_shape = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + tuple(X.get_size()) + tuple(W.get_size()) + ) + if not static_shape: + return None + + # Update op. + op = copy.deepcopy(op) + + # set layouts for X and W + self.set_layout(op.A, X.get_layout()) + self.set_layout(op.B, W.get_layout()) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + status = ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ) + if not status: + log.debug( + "Skipping due to alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + if self.use_fast_accum is not None: + is_op_fast_accum = "fastaccum" in op.configuration_name() + if self.use_fast_accum ^ is_op_fast_accum: + return None + + # Set bias layout and alignment. + status = self._set_bias_layout_and_alignment(op) + if not status: + log.debug( + "Skipping due to bias layout and alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + # Apply regex filters at the end when configuration name doesn't change anymore + if inductor_cuda_config.cutlass_op_allowlist_regex: + if not re.search( + inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() + ): + return None + if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if re.search( + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + ): + return None + + return op + + def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. + + Returns: + List[tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + + if self.cache_key in self.filtered_ops_cache: + log.debug("Using cached ops for %s", self.cache_key) + return self.filtered_ops_cache[self.cache_key] + + with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"): + maybe_ops = maybe_fetch_ops() + if maybe_ops is None: + log.debug("Cannot fetch ops from cache, generating ops from scratch") + full_ops = cutlass_utils.gen_ops() + ops = pytree.tree_flatten(full_ops)[0] + else: + log.debug("Using cached ops from cache") + ops = maybe_ops + + ops = self.global_filter_ops(ops) + + res: dict[str, cutlass_gemm_op.GemmOperation] = {} + start_time = time.time() + for op in ops: + # if changed, need to also change CUTLASS_OPERATION_KIND + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.info( + "Got cutlass configs: total number of ops: %d. Filtering took %.2f seconds", + len(res), + time.time() - start_time, + ) + sorted_res = sorted(res.items()) + ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + if len(self.filtered_ops_cache) < 50: + self.filtered_ops_cache[self.cache_key] = ret_res + else: + log.debug("Not caching ops since filtered_ops_cache has reached size 50.") + return ret_res + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. + """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (CUDATemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + CUDATemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) + + if epilogue_nodes and not self._has_tma_epilogue(op): + raise NotImplementedError( + "Non-TMA epilogue visitor tree is not supported in Cutlass." + ) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + for input_node in self.input_nodes: + if not isinstance(X.layout, FixedLayout): + input_node.freeze_layout() + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. + op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + is_scaled_mm = len(self.input_nodes) in (4, 5) + if Bias is not None and not is_scaled_mm: + assert Bias.get_dtype() == X.get_dtype() + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + assert op.C.element == op.D.element, ( + f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}" + ) + + argument_template, epilogue_template = self._get_template_args(op) + should_swap_xw: bool = False + if Bias is not None and self._has_tma_epilogue(op): + if ( + op.epilogue_schedule + != cutlass_lib.EpilogueScheduleType.EpilogueTransposed + and self.should_swap_XW(Bias) + ): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + + name_to_buffer = {node.get_name(): node for node in self.input_nodes} + # handle the fake output buffer during lowering + name_to_buffer[Y.get_name()] = Y # type: ignore[assignment] + + if epilogue_nodes or is_scaled_mm: + if epilogue_nodes: + ( + input_names, + output_names, + var_name_to_buffer_name, + evt_py_code, + ) = CutlassEVTCodegen.ir_to_evt_python_code( + Y.get_name(), epilogue_nodes, V.kernel.removed_buffers + ) + + # TODO: mlazos remove this by returning buffer metadata from + # ir_to_evt_python code + for name, buf in ( + V.graph.name_to_buffer | V.graph.graph_inputs + ).items(): + if name not in name_to_buffer: + name_to_buffer[name] = buf # type: ignore[assignment] + + D_output_name = var_name_to_buffer_name["D"] + D_output_buffer = name_to_buffer[D_output_name] + Y = D_output_buffer # type: ignore[assignment] + # Interestingly, I don't think the rest of the layout matters here since we + # use the properties of the Y buffer to fill in D's properties in the epilogue + # args. This is needed though because it defines types expected in the epilogue args. + op.D.element = cutlass_utils.torch_dtype_to_cutlass_type( + D_output_buffer.get_dtype() + ) + + assert output_names, "There should be at least one write" + + epilogue_inputs = [name_to_buffer[name] for name in input_names] + outputs = [name_to_buffer[name] for name in output_names] + else: # Scaled MM, we read the two scale matrices (and optional bias) and write a single output + bias = None if len(self.input_nodes) < 5 else self.input_nodes[4] + bias_name = bias.get_name() if bias else None + + ( + evt_read_names, + var_name_to_buffer_name, + evt_py_code, + ) = scaled_mm_evt( + self.input_nodes[2].get_name(), # scale_A + self.input_nodes[3].get_name(), # scale_B + bias_name, + Y.get_name(), + ) + + input_names = list(evt_read_names) + output_names = [] # We only need Y + epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]] + if bias: + epilogue_inputs.append(bias) + outputs = [] + + acc_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()] + ) + assert acc_dtype, "Could not determine accumulator dtype" + + evt_name, evt_args, evt_code, evt_arg_renames = self._render_evt( + op, + evt_py_code, + var_name_to_buffer_name, + name_to_buffer, + Y.get_dtype(), + acc_dtype, + ) + + inputs = [ + X, + W, + Bias, + *epilogue_inputs, # type: ignore[list-item] + Y, + *extra_inputs, + ] + input_names = [evt_arg_renames.get(name) for name in input_names] + output_names = [evt_arg_renames.get(name) for name in output_names] + + names_str = ",".join( + ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names] + ) + else: + evt_name = None + outputs = [Y] + evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + evt_code = "" + + kernel_call_signature = kernel.def_kernel( + inputs=inputs, # type: ignore[arg-type] + outputs=outputs, # type: ignore[arg-type] + names_str=names_str, + input_reorder=input_reorder, + ) + + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + + instance_definition, instance_type = self._define_gemm_instance(op, evt_name) + + options = { + "alpha": self.alpha, + "beta": self.beta, + "X": X, + "W": W, + "Y": Y, + "kernel_call_signature": kernel_call_signature, + "Bias": Bias, + "epilogue_template": epilogue_template, + "argument_template": argument_template, + "should_swap_xw": should_swap_xw, + "template": self, + "kernel": kernel, + "instance_definition": instance_definition, + "instance_type": instance_type, + "input_reorder": self.input_reorder, + "epilogue_args": evt_args, + "test_call_statement": test_call_statement, + "op_conf_name": op.configuration_name(), + "epilogue_visitor_tree": evt_code, + } + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template()).render(**options) + if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + test_runner_code = self._template_from_string( + GEMM_STANDALONE_RUNNER_TEMPLATE + ).render(**options) + res += "\n\n" + test_runner_code + + # splice to remove trailing spaces in each line + buf = IndentedBuffer() + buf.splice(res) + return buf.getvalue() + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. + """ + _, __, arg_types = kernel.args.cpp_argdefs(cutlass_utils.DTYPE_TO_CUTLASS_TYPE) + arg_names = [name.strip() for name in names_str.strip().split(",")] + arg_names = self._update_arg_names_for_test_call_statement( + arg_names, input_nodes + ) + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + buffer_renames: dict[str, str], + name_to_buffer: dict[str, Buffer], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + """ + CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS3xGemmTemplate( + input_nodes, + layout, + alpha, + beta, + input_reorder, + use_fast_accum, + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + @functools.lru_cache(1) + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal3x] + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_3X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) + + @staticmethod + def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 + ) -> bool: # type: ignore[name-defined] + """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return CUTLASS3xGemmTemplate._has_tma_epilogue(op) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM). + + This function checks compatibility of A, B, and possibly C operand layouts for + a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'. + It verifies requirements such as matching data types, minimum rank, and suitability + for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`, + `addmm`, `bmm`, `baddbmm`, etc. + + Args: + layouts (List[Layout]): List containing 2 or 3 Layout objects representing + the input matrices A, B, and possibly C. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert 2 <= len(layouts) <= 5 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = list(V.graph.sizevars.size_hints(A_layout.size)) + B_size = list(V.graph.sizevars.size_hints(B_layout.size)) + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? + while len(A_size) < len(B_size): + A_size.insert(0, 1) + while len(B_size) < len(A_size): + B_size.insert(0, 1) + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and A_size[-1] != 1: + return False + if K != B_size[-2] and B_size[-1] != 1: + return False + # check batch dim broadcastable + for i in range(len(A_size) - 2): + if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1: + return False + if len(layouts) == 3: + C_layout = layouts[2] + C_size = [V.graph.sizevars.size_hint(i) for i in C_layout.size] + while len(C_size) < len(A_size): + C_size.insert(0, 1) + # check batch dims + for i in range(len(A_size) - 2): + bd = max(A_size[i], B_size[i]) + if bd != C_size[i] and C_size[i] != 1: + return False + if len(C_size) > len(A_size): + # This may happen if the last elements of C are contiguous and + # their multiplied size equals the last dim size of B + if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1: + return False + remaining_size = 1 + for i in range(len(A_size) - 1, len(C_size)): + remaining_size *= C_size[i] + if N != remaining_size and remaining_size != 1: + return False + return True + assert len(C_size) == len(A_size) + if M != C_size[-2] and C_size[-2] != 1: + return False + if N != C_size[-1] and C_size[-1] != 1: + return False + return True + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + var_name_to_buffer_name: dict[str, str], + name_to_buffer: dict[str, Buffer], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str, EVTArgRenames]: + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + + acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) + output_dtype = torch_dtype_to_cutlass_type(output_dtype) + + examples = create_example_tensors( + var_name_to_buffer_name, + name_to_buffer, # type: ignore[arg-type] + V.graph.sizevars.size_hint, + ) + evt_name, evt_args, evt_code, arg_renames = trace( + evt_py_code, + examples, + acc_dtype, + output_dtype, + op.tile_description, # type: ignore[attr-defined] + op.epilogue_schedule, # type: ignore[attr-defined] + {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] + V.graph.sizevars.size_hint, + ) + + return ( + evt_name, + evt_args, + evt_code, + arg_renames, + ) + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None + if has_bias: + Bias = self.input_nodes[2] + # bias dtype + op.C.element = cutlass_utils.torch_dtype_to_cutlass_type( + Bias.get_layout().dtype + ) + + # Bias layout + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + op.C.layout = bias_layout + + # Bias alignment + status = self.set_alignment(Bias.get_layout(), op.C) + if not status: + return False + else: + op.C.element = cutlass_lib.DataType.void + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions + + emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] + + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None + inputs: list[Optional[Buffer]] = [] + names: list[str] = [] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[2] is None: + del arg_names[2] + else: + # Reorder them as Bias, A, B + if self.input_reorder is not None: + arg_names[0 : len(self.input_reorder)] = [ + arg_names[i] for i in self.input_reorder + ] + return arg_names + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = { + "alpha": alpha, + "beta": beta, + "X": X, + "W": W, + "Y": Y, + "Bias": Bias, + "template": self, + "kernel": kernel, + "M": "M", + "N": "N", + "epilogue_args": epilogue_args, + } + assert epilogue_template is not None + + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) # type: ignore[union-attr] + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + assert old_layout.device is not None + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), # type: ignore[union-attr] + new_stride, + old_layout.offset, # type: ignore[union-attr] + ) + return Buffer(name=node.get_name(), layout=new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments + + +class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = False, + **extra_kwargs, + ) -> None: + template = CUTLASS2xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse] + + @staticmethod + def _has_tma_epilogue(self) -> bool: + return False + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_2X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return (GEMM_ARGS_SPARSE_CUTLASS_2X, None) + + return (GEMM_ARGS_CUTLASS_2X, None) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) != 2: + return False + if len(B_layout.size) != 2: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + K = max(A_size[1], B_size[0]) + return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0] + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + X, W = self.input_nodes[0], self.input_nodes[1] + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return X.get_size()[1] * 2 == W.get_size()[0] + + return X.get_size()[1] == W.get_size()[0] + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + return True + + # SparseGemm in CUTLASS has specific alignment check that for + # small k could make some of the choices throw kMisalignedOperand + # CUTLASS error when run, see: + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # So, let's skip these choices if that would be the case. + X = self.input_nodes[0] + return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0 + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + op.C.layout = op.D.layout + return True + + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return False + if not self.set_alignment(Bias.get_layout(), op.C): + return False + else: + op.C.layout = op.D.layout + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + emitter = cutlass_gemm_op.EmitSparseGemmInstance() + else: + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + Bias = None + Meta = self.input_nodes[2] + else: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Meta = None + inputs = [Meta] + names = ["Meta"] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[3] is None: + del arg_names[3] + if input_nodes[2] is None: + del arg_names[2] + return arg_names + + def render_gemm_arguments( + self, + instance_type: str, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Meta: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + instance_type (str): GEMM instance type. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Meta (IRNode): The meta tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = { + "instance_type": instance_type, + "alpha": alpha, + "beta": beta, + "X": X, + "W": W, + "Y": Y, + "Bias": Bias, + "Meta": Meta, + "template": self, + "kernel": kernel, + "M": "M", + "N": "N", + "epilogue_args": epilogue_args, + } + + if epilogue_template is None: + arguments = self._template_from_string(argument_template).render( + split_k=1, **options + ) + return arguments + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a17f04b0a1b5a25ee623880eac8daf56a63e8ef4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py @@ -0,0 +1,507 @@ +# mypy: allow-untyped-defs +import functools +import json +from enum import Enum +from typing import Any, Optional + +from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + + +class CUTLASSOperationSerializer: + """Serializes and deserializes CUTLASS GEMM operations to/from JSON. + + Handles GemmOperation objects and their nested components (TileDescription, TensorDescription). + """ + + # not used, but keeping in case we want to generalize the serializer + _SUPPORTED_CLASSES: list[str] = [ + "GemmOperation", + "GemmKind", + "TileDescription", + "TensorDescription", + "DataType", + "EpilogueFunctor", + "EpilogueFunctor3x", + "SwizzlingFunctor", + "KernelScheduleType", + "EpilogueScheduleType", + "TileSchedulerType", + ] + + @classmethod + def serialize(cls, operation: "GemmOperation") -> str: # type: ignore[name-defined] # noqa: F821 + """Serialize a GEMM operation to JSON string. + + Args: + operation: GemmOperation object + + Returns: + str: JSON string representation of the operation + """ + assert operation.__class__.__qualname__ == "GemmOperation", ( + "Only GemmOperation objects are supported via the main API" + ) + return json.dumps(cls._gemm_operation_to_json(operation)) + + @classmethod + def deserialize(cls, json_str: str) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Deserialize JSON string to a GEMM operation. + + Args: + json_str: JSON string of a GEMM operation + + Returns: + GemmOperation: Reconstructed operation + """ + json_dict = json.loads(json_str) + return cls._json_to_gemm_operation(json_dict) + + @classmethod + def _gemm_operation_to_json(cls, operation: "GemmOperation") -> dict[str, Any]: # type: ignore[name-defined] # noqa: F821 + """Convert GemmOperation to JSON-serializable dict. + + Args: + operation: GemmOperation object + + Returns: + dict: Dictionary representation + """ + from cutlass_library.library import TensorDescription + + # Create the main dictionary with required and optional parameters + result = { + # Required parameters + "gemm_kind": cls._enum_to_json(operation.gemm_kind), + "arch": operation.arch, + "tile_description": cls._tile_description_to_json( + operation.tile_description + ), + "A": cls._tensor_description_to_json(operation.A), + "B": cls._tensor_description_to_json(operation.B), + "C": cls._tensor_description_to_json(operation.C), + "element_epilogue": cls._enum_to_json(operation.element_epilogue), + # Optional parameters + "epilogue_functor": cls._enum_to_json(operation.epilogue_functor), + "swizzling_functor": cls._enum_to_json(operation.swizzling_functor), + "D": cls._tensor_description_to_json(operation.D) if operation.D else None, + "kernel_schedule": cls._enum_to_json(operation.kernel_schedule), + "epilogue_schedule": cls._enum_to_json(operation.epilogue_schedule), + "tile_scheduler": cls._enum_to_json(operation.tile_scheduler), + } + + # Process optional attributes + optional_attrs = [ + "mixed_input_mode", + "mixed_input_shuffle", + "ScaleFactorA", + "ScaleFactorB", + "ScaleFactorD", + "ScaleFactorMVecSize", + "ScaleFactorNVecSize", + "ScaleFactorKVecSize", + "ScaleFactorVectorSize", + "is_3x", + ] + + for attr in optional_attrs: + if not hasattr(operation, attr): + continue + + value = getattr(operation, attr) + + if isinstance(value, TensorDescription): + result[attr] = cls._tensor_description_to_json(value) + elif isinstance(value, Enum): + result[attr] = cls._enum_to_json(value) + else: + result[attr] = value + + return result + + @classmethod + def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Convert JSON dict to GemmOperation object. + + Args: + json_dict: Dictionary representation + + Returns: + GemmOperation: Reconstructed object + """ + from cutlass_library import DataType + from cutlass_library.gemm_operation import GemmKind, GemmOperation + from cutlass_library.library import ( + EpilogueFunctor, + EpilogueFunctor3x, + EpilogueScheduleType, + KernelScheduleType, + MixedInputMode, + SwizzlingFunctor, + TileSchedulerType, + ) + + # Extract constructor parameters from the JSON dictionary + gemm_kind = cls._json_to_enum(json_dict["gemm_kind"], GemmKind) + arch = json_dict["arch"] + tile_description = cls._json_to_tile_description(json_dict["tile_description"]) + A = cls._json_to_tensor_description(json_dict.get("A"), "A") + B = cls._json_to_tensor_description(json_dict.get("B"), "B") + C = cls._json_to_tensor_description(json_dict.get("C"), "C") + element_epilogue = cls._json_to_enum(json_dict["element_epilogue"], DataType) + + # Get optional parameters with defaults + epilogue_functor = cls._json_to_enum( + json_dict.get("epilogue_functor"), + EpilogueFunctor3x if json_dict.get("is_3x") else EpilogueFunctor, + ) + swizzling_functor = cls._json_to_enum( + json_dict.get("swizzling_functor"), SwizzlingFunctor + ) + D = cls._json_to_tensor_description(json_dict.get("D"), "D") + kernel_schedule = cls._json_to_enum( + json_dict.get("kernel_schedule"), KernelScheduleType + ) + epilogue_schedule = cls._json_to_enum( + json_dict.get("epilogue_schedule"), EpilogueScheduleType + ) + tile_scheduler = cls._json_to_enum( + json_dict.get("tile_scheduler"), TileSchedulerType + ) + + mixed_input_mode = cls._json_to_enum( + json_dict.get("mixed_input_mode"), MixedInputMode + ) + mixed_input_shuffle = json_dict.get("mixed_input_shuffle", False) + + # Scale factors + ScaleFactorA = cls._json_to_enum(json_dict.get("ScaleFactorA"), DataType) + ScaleFactorB = cls._json_to_enum(json_dict.get("ScaleFactorB"), DataType) + + ScaleFactorD = None + if "ScaleFactorD" in json_dict and "ScaleFactorVectorSize" in json_dict: + ScaleFactorD = { + "tensor": cls._json_to_tensor_description( + json_dict.get("ScaleFactorD"), "ScaleFactorD" + ), + "vector_size": json_dict.get("ScaleFactorVectorSize"), + } + + ScaleFactorMVecSize = json_dict.get("ScaleFactorMVecSize") + ScaleFactorNVecSize = json_dict.get("ScaleFactorNVecSize") + ScaleFactorKVecSize = json_dict.get("ScaleFactorKVecSize") + + # Create the GemmOperation with the extracted parameters + operation = GemmOperation( + gemm_kind=gemm_kind, + arch=arch, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + mixed_input_mode=mixed_input_mode, + mixed_input_shuffle=mixed_input_shuffle, + ScaleFactorA=ScaleFactorA, + ScaleFactorB=ScaleFactorB, + ScaleFactorD=ScaleFactorD, + ScaleFactorMVecSize=ScaleFactorMVecSize, + ScaleFactorNVecSize=ScaleFactorNVecSize, + ScaleFactorKVecSize=ScaleFactorKVecSize, + ) + + return operation + + @classmethod + @functools.lru_cache(None) + def _tile_description_to_json(cls, tile_desc: "TileDescription") -> str: # type: ignore[name-defined] # noqa: F821 + """ + Convert TileDescription to JSON string. + + Args: + tile_desc: TileDescription object + + Returns: + str: JSON string representation + """ + + # Create the main dictionary with field names matching TileDescription constructor parameters + result = { + "threadblock_shape": tile_desc.threadblock_shape, + "stages": tile_desc.stages, + "warp_count": tile_desc.warp_count, + "math_instruction": cls._math_instruction_to_json( + tile_desc.math_instruction + ), + "min_compute": tile_desc.minimum_compute_capability, # Store as min_compute for constructor + "max_compute": tile_desc.maximum_compute_capability, # Store as max_compute for constructor + "cluster_shape": tile_desc.cluster_shape, + "explicit_vector_sizes": tile_desc.explicit_vector_sizes, + } + + # Add tile_shape if it exists and differs from threadblock_shape + if ( + hasattr(tile_desc, "tile_shape") + and tile_desc.tile_shape != tile_desc.threadblock_shape + ): + result["tile_shape"] = tile_desc.tile_shape + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_tile_description( + cls, json_dict: Optional[str] + ) -> Optional["TileDescription"]: # type: ignore[name-defined] # noqa: F821 + """ + Convert JSON dict to TileDescription object. + + Args: + json_dict: Dictionary representation + + Returns: + TileDescription: Reconstructed object + """ + if json_dict is None: + return None + + tile_dict = json.loads(json_dict) + + from cutlass_library.library import TileDescription + + math_instruction = cls._json_to_math_instruction(tile_dict["math_instruction"]) + + # Get compute capability values, checking both naming conventions + min_compute = tile_dict.get( + "min_compute", tile_dict.get("minimum_compute_capability") + ) + max_compute = tile_dict.get( + "max_compute", tile_dict.get("maximum_compute_capability") + ) + + # Get cluster shape with default value + cluster_shape = tile_dict.get("cluster_shape", [1, 1, 1]) + + # Create the TileDescription object + tile_desc = TileDescription( + threadblock_shape=tile_dict["threadblock_shape"], + stages=tile_dict["stages"], + warp_count=tile_dict["warp_count"], + math_instruction=math_instruction, + min_compute=min_compute, + max_compute=max_compute, + cluster_shape=cluster_shape, + explicit_vector_sizes=tile_dict.get("explicit_vector_sizes"), + ) + + # Set tile_shape if it exists and differs from threadblock_shape + if ( + "tile_shape" in tile_dict + and tile_dict["tile_shape"] != tile_dict["threadblock_shape"] + ): + tile_desc.tile_shape = tile_dict["tile_shape"] + + return tile_desc + + @classmethod + @functools.lru_cache(None) + def _math_instruction_to_json( + cls, + math_instruction: Optional["MathInstruction"], # type: ignore[name-defined] # noqa: F821 + ) -> Optional[str]: + """Convert MathInstruction to JSON string. + + Args: + math_instruction: MathInstruction object + + Returns: + Optional[str]: JSON string representation or None + """ + if math_instruction is None: + return None + + result = { + "instruction_shape": math_instruction.instruction_shape, + "element_a": cls._enum_to_json(math_instruction.element_a), + "element_b": cls._enum_to_json(math_instruction.element_b), + "element_accumulator": cls._enum_to_json( + math_instruction.element_accumulator + ), + "opcode_class": cls._enum_to_json(math_instruction.opcode_class), + "math_operation": cls._enum_to_json(math_instruction.math_operation), + "element_scale_factor": cls._enum_to_json( + math_instruction.element_scale_factor + ), + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_math_instruction( + cls, json_dict: Optional[str] + ) -> Optional["MathInstruction"]: # type: ignore[name-defined] # noqa: F821 + """Convert JSON string to MathInstruction object. + + Args: + json_dict: JSON string representation + + Returns: + Optional[MathInstruction]: Reconstructed object or None + """ + if json_dict is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import MathInstruction, MathOperation, OpcodeClass + + mi_dict = json.loads(json_dict) + + # Convert string enum names back to enum values + element_a = cls._json_to_enum(mi_dict["element_a"], DataType) + element_b = cls._json_to_enum(mi_dict["element_b"], DataType) + element_acc = cls._json_to_enum(mi_dict["element_accumulator"], DataType) + + # Get the opcode_class enum + opcode_class = cls._json_to_enum(mi_dict["opcode_class"], OpcodeClass) + + # Get the math_operation enum + math_op = cls._json_to_enum(mi_dict["math_operation"], MathOperation) + + # Create the MathInstruction object + math_instruction_obj = MathInstruction( + instruction_shape=mi_dict["instruction_shape"], + element_a=element_a, + element_b=element_b, + element_accumulator=element_acc, + opcode_class=opcode_class, + math_operation=math_op, + ) + + # Add element_scale_factor if it exists + if ( + "element_scale_factor" in mi_dict + and mi_dict["element_scale_factor"] is not None + ): + math_instruction_obj.element_scale_factor = cls._json_to_enum( + mi_dict["element_scale_factor"], DataType + ) + + return math_instruction_obj + + @classmethod + @functools.lru_cache(None) + def _tensor_description_to_json( + cls, + tensor_desc: Optional["TensorDescription"], # type: ignore[name-defined] # noqa: F821 + ) -> Optional[str]: + """Convert TensorDescription to JSON string. + + Args: + tensor_desc: TensorDescription object + + Returns: + Optional[str]: JSON string representation or None + """ + if tensor_desc is None: + return None + + result = { + "element": cls._enum_to_json(tensor_desc.element), + "layout": cls._enum_to_json(tensor_desc.layout), + "alignment": tensor_desc.alignment, + "complex_transform": cls._enum_to_json(tensor_desc.complex_transform), + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_tensor_description( + cls, + json_dict: Optional[str], + tensor_name: Optional[str] = None, + ) -> Optional["TensorDescription"]: # type: ignore[name-defined] # noqa: F821 + """Convert JSON string to TensorDescription object. + + Args: + json_dict: JSON string representation + tensor_name: Name of the tensor to avoid cache in the same op + + Returns: + Optional[TensorDescription]: Reconstructed object or None + """ + if json_dict is None: + return None + + tensor_dict = json.loads(json_dict) + + from cutlass_library import DataType + from cutlass_library.library import ( + ComplexTransform, + LayoutType, + TensorDescription, + ) + + element = cls._json_to_enum(tensor_dict["element"], DataType) + layout = cls._json_to_enum(tensor_dict["layout"], LayoutType) + alignment = tensor_dict["alignment"] + complex_transform = cls._json_to_enum( + tensor_dict["complex_transform"], ComplexTransform + ) + + return TensorDescription(element, layout, alignment, complex_transform) + + @classmethod + @functools.lru_cache(None) + def _enum_to_json(cls, enum_value: Optional[Enum]) -> Optional[str]: + """Convert enum value to JSON string. + + Args: + enum_value: Enum value + + Returns: + Optional[str]: JSON string representation or None + """ + if enum_value is None: + return None + + result = { + "type": enum_value.__class__.__name__, + "name": enum_value.name, + } + + return json.dumps(result) + + @classmethod + @functools.lru_cache(None) + def _json_to_enum(cls, json_dict: Optional[str], enum_class: Any) -> Optional[Enum]: + """Convert JSON string to enum value. + + Format: {name: "EnumName", value: 1} + + Args: + json_dict: JSON string representation + enum_class: Target enum class + + Returns: + Optional[Enum]: Reconstructed enum value or None + """ + if json_dict is None: + return None + + enum_dict = json.loads(json_dict) + + return enum_class[enum_dict["name"]] + + +@functools.lru_cache(1) +def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]: + if not try_import_cutlass(): + return None + return CUTLASSOperationSerializer() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f12fa963fd60c00deb9f36f9515e3e794c9529ef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/__init__.py @@ -0,0 +1,8 @@ +# mypy: allow-untyped-defs +from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller + + +__all__ = [ + "CuteDSLTemplate", + "CuteDSLTemplateCaller", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..17f850c8078c8d058bad8007e9cf14b69599003b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py @@ -0,0 +1,29 @@ +# mypy: disable-error-code=import-not-found +# pyrefly: ignore [import-error] +import cutlass.cute as cute + + +@cute.jit # type: ignore[misc] +def ssa_to_indexable(ssa_value: cute.TensorSSA, dtype: str) -> cute.Numeric: + """ + Convert SSA form to indexable non-SSA form. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices in tensor loads. This converts SSA → fragment → scalar for indexing. + """ + frag = cute.make_rmem_tensor(1, dtype) + frag.store(ssa_value) + return frag[0] + + +@cute.jit # type: ignore[misc] +def result_to_ssa(value: cute.Numeric, dtype: str) -> cute.TensorSSA: + """ + Convert non-SSA result back to SSA form. + + After performing operations with non-SSA values (like indexed loads), + convert the result back to SSA form for further computation. + """ + frag = cute.make_rmem_tensor(1, dtype) + frag[0] = value + return frag.load() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..4772ee1541726ec6b016a39f8974d15e676da6c8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -0,0 +1,599 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import logging +import textwrap +from collections.abc import Callable +from typing import Any, Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.common import ( + CSE, + CSEVariable, + IndentedBuffer, + Kernel, + ValueRanges, +) +from torch._inductor.ir import ( + BaseView, + Buffer, + ComputedBuffer, + ExternKernel, + InputBuffer, + MutableBox, + ReinterpretView, +) +from torch._inductor.ops_handler import StoreMode +from torch._inductor.utils import OrderedSet +from torch._inductor.virtualized import V + +from ...utils import sympy_index_symbol +from .cutedsl_op_overrides import CuteDSLOpOverrides + + +# TODO setting the 'main' kernel w/ this suffix. We have 3 should probably just auto generate this +MAIN_SUFFIX = "main" + + +log = logging.getLogger(__name__) +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class CuteDSLKernelWrapper: + """Wrapper to provide .run() interface for CuteDSL kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("CuteDSL kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the CuteDSL kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +@dataclasses.dataclass +class CuteDSLSubgraphInfo: + """Minimal subgraph info for CuteDSL kernels.""" + + body: IndentedBuffer + template_mask: Optional[str] = None + template_out: Optional[str] = None + cse: Optional[CSE[Any]] = None + + def __post_init__(self): + self.only_copy_if_non_none_fields = ("cse",) + + def to_dict(self): + return { + field.name: getattr(self, field.name) for field in dataclasses.fields(self) + } + + +class CuteDSLTemplateKernel(Kernel): + """ + Template kernel implementation for CuteDSL (CUTLASS Python DSL). + Handles code generation and argument management for CuteDSL CUDA kernels. + Provides CuteDSL-specific functionality for tensor conversion and kernel configuration. + """ + + def __init__( + self, + kernel_name: str, + input_nodes: list[Buffer], + output_node: Buffer, + subgraphs: Optional[list[Buffer]] = None, + ) -> None: + # Call parent Kernel constructor + super().__init__() + self.kernel_name = kernel_name + self.input_nodes = input_nodes + self.output_node = output_node + self.subgraphs = subgraphs + self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {} + + # Template attributes + self.body: IndentedBuffer = IndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + self.template_indices: Optional[list[Any]] = None + self.render_hooks: dict[str, Any] = {} + + # TODO Additional attributes needed by template system + self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() + self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() + self.named_input_nodes: dict[str, Buffer] = {} + + # Create named input nodes mapping + for i, input_node in enumerate(input_nodes): + node_name = getattr(input_node, "name", f"input_{i}") + self.named_input_nodes[node_name] = input_node + + self.cse = CSE(name_prefix="tmp") + + # Track all tensor buffers added during modification processing + self.collected_tensor_buffers: list[str] = [] + + def kexpr(self, expr: sympy.Expr) -> str: + """Convert sympy expression to CuteDSL string representation.""" + return str(expr) + + def gen_imports(self) -> str: + """Generate common imports for CuteDSL templates.""" + imports = IndentedBuffer() + imports.splice( + """ + import torch + import cutlass + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + import cuda.bindings.driver as cuda + from cutlass._mlir.dialects import math as mlir_math + import operator + from torch._inductor.codegen.cutedsl._cutedsl_utils import ssa_to_indexable, result_to_ssa + """ + ) + return imports.getvalue() + + def gen_defines(self, **kwargs) -> str: + """Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines.""" + params = IndentedBuffer() + for name, val in kwargs.items(): + params.writeline(f"{name}: cutlass.Constexpr = {val}") + return params.getvalue() + + def render(self, template, **kwargs): + from torch._inductor.select_algorithm import PartialRender + + """Render the kernel using the template, returning PartialRender object with hooks.""" + # Available {{}} hooks for jinja rendering + template_env = { + "def_kernel": self.def_kernel, + "gen_defines": lambda: self.gen_defines(**kwargs), + "get_output": self.get_output, + "get_tensor_buffers": self.get_tensor_buffers, + "unpack_buffers": self.unpack_buffers, + "modification": self.modification, + "set_cute_hash": self.set_cute_hash, + } + + # Render the template with the environment and provided kwargs + rendered_code = template.render( + kernel_name=self.kernel_name, + input_nodes=self.input_nodes, + output_node=self.output_node, + **template_env, + **kwargs, + ) + + # Always prepend the common imports + imports = self.gen_imports() + full_code = imports + rendered_code + + return PartialRender(full_code, self.render_hooks) + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + """Set the active subgraph body for template processing.""" + assert all( + hasattr(self, field.name) + for field in dataclasses.fields(CuteDSLSubgraphInfo) + ) + old_state = { + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + + if body_name not in self.subgraph_bodies: + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + cse=None, + ) + + subgraph = self.subgraph_bodies[body_name] + for key, value in subgraph.to_dict().items(): + if value is None and key in getattr( + subgraph, "only_copy_if_non_none_fields", () + ): + continue + setattr(self, key, value) + + try: + yield + finally: + # Save current state back to subgraph + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + **{ + key.name: getattr(self, key.name) + for key in dataclasses.fields(CuteDSLSubgraphInfo) + } + ) + # Restore old state + for key, value in old_state.items(): + setattr(self, key, value) + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False): + """Create a new subgraph body for template processing.""" + assert body_name not in self.subgraph_bodies, ( + f"Subgraph body '{body_name}' already exists" + ) + new_cse = self.cse.clone() if clear_cse else None + self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( + body=IndentedBuffer(), + template_mask=None, + template_out=None, + cse=new_cse, + ) + with self.set_subgraph_body(body_name): + yield + + def _get_reinterpret_view(self, node) -> ReinterpretView | None: + """Extract or convert to ReinterpretView from a node, handling all views.""" + while isinstance(node, MutableBox): + node = node.data + if isinstance(node, BaseView): + return ExternKernel.convert_to_reinterpret_view(node) + return None + + def def_kernel(self, *argnames): + """Define kernel function signature for CuteDSL templates. + + When inputs are ReinterpretViews of the same underlying buffer (e.g., Q/K/V + from fused QKV projection), we generate separate arguments for each input + even though they share the same underlying buffer. + """ + renames = IndentedBuffer(initial_indent=1) + + # Track template input args - each input gets its own arg even if buffers are shared + self._template_input_args: list[tuple[str, Buffer]] = [] + self._seen_input_args: OrderedSet[str] = OrderedSet() + + for i, input_node in enumerate(self.input_nodes): + buf_name = input_node.get_name() + # Register with args system (may deduplicate, but we track separately) + self.args.input(buf_name) + + if i < len(argnames): + template_name = argnames[i] + arg_name = f"arg_{template_name}" + self.args.input_buffers[buf_name] = arg_name + renames.writeline(f"{template_name} = {arg_name}") + self._template_input_args.append((arg_name, input_node)) + self._seen_input_args.add(arg_name) + + if self.output_node: + self.args.output(self.output_node.get_name()) + + def hook(): + # Generate signature with template input args plus additional args (output, sizevars) + code = IndentedBuffer() + code.writeline(f"# Kernel function signature: {self.kernel_name}") + + # Start with template input args + params = [arg_name for arg_name, _ in self._template_input_args] + + # Get additional args from python_argdefs (output, sizevars, etc.) + arg_defs, _, _, _ = self.args.python_argdefs() + for arg_def in arg_defs: + if arg_def.full_name() not in self._seen_input_args: + params.append(arg_def.full_name()) + + params.append("stream") + code.writeline( + f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):" + ) + with code.indent(): + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + # Placeholder-based rendering: hook will be called when template encounters "" + self.render_hooks[""] = hook + return "" + + def get_output(self): + """Get the actual argument name for the output buffer.""" + assert self.output_node, "Output node must exist to get output buffer name" + buf_name = self.output_node.get_name() + output = self.args.output_buffers.get(buf_name, None) + if output is None: + raise ValueError(f"Output buffer '{buf_name}' not found in args") + return output + + def set_cute_hash(self, func_name: str, suffix: str = ""): + """Generate code to set __cute_hash__ on a codegen function. + + This allows hash_callable in flash_attn to skip expensive runtime hashing + for Inductor-generated functions. The hash is based on the kernel name + which already contains a unique hash suffix. + """ + hash_value = f"{self.kernel_name}_{suffix}" if suffix else self.kernel_name + return f'{func_name}.__cute_hash__ = "{hash_value}"' + + def get_tensor_buffers(self): + """Get list of tensor buffer names that were collected during modifications.""" + return self.collected_tensor_buffers + + def unpack_buffers(self, buffer_list_name: str, *, indent_width: int = 4): + """Generate buffer unpacking code via render hook.""" + + def hook(): + tensor_buffers = self.get_tensor_buffers() + if not tensor_buffers: + return "" + + # Generate unpacking assignments: in_ptr4 = buffers[0], etc. + unpacking_lines = [] + for i, buffer_name in enumerate(tensor_buffers): + # pyrefly: ignore [bad-argument-type] + unpacking_lines.append(f"{buffer_name} = {buffer_list_name}[{i}]") + + indent = " " * indent_width + return "\n" + indent + ("\n" + indent).join(unpacking_lines) + + # Register the hook and return placeholder + placeholder = "" + # TODO: I think double invoking is fine for this specific hook + # assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node=None): + """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel. + + For inputs that are ReinterpretViews (e.g., Q/K/V slices from fused QKV), + we generate reinterpret_tensor() calls to properly handle the views. + """ + wrapper = V.graph.wrapper_code + + # Build call args matching the signature generated in `def_kernel` + call_args = [] + arg_types = [] + + for _, input_node in self._template_input_args: + reinterpret_view = self._get_reinterpret_view(input_node) + if reinterpret_view is not None: + call_args.append(reinterpret_view.codegen_reference()) + else: + call_args.append(input_node.get_name()) + arg_types.append(V.graph.get_dtype(input_node.get_name())) + + # Add additional args from python_argdefs (output, sizevars, ..) + orig_arg_defs, orig_call_args, _, orig_arg_types = self.args.python_argdefs() + for arg_def, call_arg, arg_type in zip( + orig_arg_defs, orig_call_args, orig_arg_types + ): + # dedupe + if arg_def.full_name() not in self._seen_input_args: + call_args.append(call_arg) + arg_types.append(arg_type) + + # TODO this karg really should not be called `triton` + wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types) + + def _get_subgraph(self, subgraph_number: int): + """Get subgraph by number for modification processing.""" + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len(self.subgraphs), ( + f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + ) + assert self.body.getvalue() == "", ( + "Body should be clear before adding a modification" + ) + return self.subgraphs[subgraph_number] + + def modification( + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, + ) -> str: + """Generate CuteDSL code for a subgraph modification.""" + # Find unique name to avoid collisions between multiple modifications of same subgraph + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapperCuteDSL( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_kernel_handler(self), V.set_ops_handler(modification_handler): + assert isinstance(subgraph, (ComputedBuffer, list)), ( + f"Expected ComputedBuffer or List[ComputedBuffer], got {type(subgraph)}" + ) + + if isinstance(subgraph, list): + raise NotImplementedError( + "Scatter graphs are not supported for CuteDSL" + ) + + if isinstance(subgraph.data, InputBuffer): + # grad_score_mod can be InputBuffers + out = subgraph.data.make_loader()(()) + else: + # Inline a pointwise lowering into the template + out = subgraph.data.inner_fn(()) + + if output_name is not None: + assert out is not None, ( + f"Expected computation result for named output {output_name}" + ) + self.body.writeline(f"{output_name} = {out.value}") + else: + # Side-effect only: no output assignment (currently only for scatter operations) + raise NotImplementedError( + "Side-effect only modifications not yet supported for CuteDSL" + ) + + # Add Buffers that were added during modification + self.collected_tensor_buffers.extend(modification_handler.tensor_buffers) + + return self.body.getvalue() + + +class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined] + """ + Wrapper handler that enables CuteDSL code generation during subgraph modifications. + + This class sits between the PyTorch IR and CuteDSL code generation, providing: + 1. Operation substitution: converts PyTorch ops to CuteDSL equivalents via CuteDSLOpOverrides + 2. Placeholder handling: resolves fixed_inputs during template processing + 3. Limited operation support: currently restricted to pointwise operations + + """ + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: dict[str, Any], + mask: Optional[str], + ): + cutedsl_ops = CuteDSLOpOverrides() + super().__init__(cutedsl_ops) + self.name = f"CuteDSLPlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + # Track tensor buffers that get added during modification processing + self.tensor_buffers: list[str] = [] + + def _get_input_dtype(self, name: str) -> torch.dtype: + """Get the dtype for an input from the kernel's named_input_nodes.""" + if name in self.kernel.named_input_nodes: + return self.kernel.named_input_nodes[name].dtype + # TODO: Fallback for common dimension names - should be replaced with proper dtype tracking + return torch.float32 if name not in ("b", "h", "m", "n") else torch.int32 + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed(template args) input for CuteDSL.""" + if name not in self.fixed_inputs: + var = self._add_kernel_input(name) + buffer = V.graph.get_buffer(name) + var_dtype = buffer.dtype + + cute_dtype = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get( + var_dtype, "cutlass.Float32" + ) + renamed_index = self.kernel.rename_indexing(index) + + idx_var = self._emit_scalar_fragment( + self.kernel.kexpr(renamed_index), "cutlass.Int32", torch.int32 + ) + + val_frag = self.kernel.cse.newvar(dtype=var_dtype) + self.kernel.body.writeline( + f"{val_frag} = cute.make_rmem_tensor(1, {cute_dtype})" + ) + + self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{idx_var}])") + + final_expr = f"{val_frag}.load()" + + if ( + var_dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + final_expr = f"({final_expr}).to(cutlass.Float32)" + var_dtype = torch.float32 + + out = self.kernel.cse.generate( + self.kernel.body, + final_expr, + dtype=var_dtype, + bounds=ValueRanges.unknown(), + ) + return out + + value = self.fixed_inputs[name] + dtype = self._get_input_dtype(name) + + return self.kernel.cse.generate( + self.kernel.body, value, bounds=ValueRanges.unknown(), dtype=dtype + ) + + def _emit_scalar_fragment( + self, expr_str: str, cute_dtype: str, torch_dtype: torch.dtype + ) -> str: + """ + Convert SSA expression to indexable scalar for tensor loads. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices. This generates code to convert SSA → indexable scalar. + """ + result = self.kernel.cse.newvar(dtype=torch_dtype) + self.kernel.body.writeline( + f"{result} = ssa_to_indexable({expr_str}, {cute_dtype})" + ) + return str(result) + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + # pyrefly: ignore [bad-override] + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> str: + raise NotImplementedError( + "Store operations not supported - CuteDSL limited to read-only operations" + ) + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + # Get the remapped name that will be used in the kernel + remapped_name = self.kernel.args.input(name) + # Track the remapped name for later collection + if remapped_name not in self.tensor_buffers: + self.tensor_buffers.append(remapped_name) + return remapped_name + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + renamed = self.kernel.rename_indexing(index) + return self.kernel.kexpr(renamed) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + try: + return getattr(self._inner, name)(*args, **kwargs) + except NotImplementedError as e: + bar = "=" * 80 + msg = textwrap.dedent(f""" + {bar} + UNSUPPORTED CUTEDSL OPERATION: '{name}' + {bar} + This operation is not yet implemented in Inductor. + + Please open an issue at: https://github.com/pytorch/pytorch/issues + with the following information: + + Operation: {name} + Args: {args!r} + Kwargs: {kwargs!r} + + Title your issue: [CuteDSL] Missing operation: {name} + {bar} + """).strip() + raise NotImplementedError(msg) from e diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3ca75c52adcf96bd1f8e4270eff933b953c1c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py @@ -0,0 +1,360 @@ +# mypy: allow-untyped-defs +""" +CuteDSL-specific operation overrides for pointwise operations. + +This module provides CuteDSL implementations of common operations used in +template kernels, particularly for flex attention modifications. +""" + +import math +from typing import Optional, Union + +import sympy + +import torch +from torch._inductor.codegen.common import CSEVariable, OpOverrides +from torch._inductor.virtualized import OpsValue, V +from torch.utils._sympy.value_ranges import ValueRanges + + +CuteDSLArg = Union[CSEVariable, str] + + +def upcast_compute_type(dtype: torch.dtype) -> torch.dtype: + """Maybe upcast [b]float16 to float32""" + if dtype in (torch.float16, torch.bfloat16): + return torch.float32 + return dtype + + +class CuteDSLOpOverrides(OpOverrides): + """ + CuteDSL-specific operation overrides that generate code using CuteDSL syntax. + + CuteDSL TensorSSA objects have built-in operator overloads (__add__, __mul__, etc.) + and math functions (cute.math.exp, cute.math.sqrt, etc.) + """ + + TORCH_TO_CUTE_DTYPE = { + torch.float16: "cutlass.Float16", + torch.bfloat16: "cutlass.BFloat16", + torch.float32: "cutlass.Float32", + torch.float64: "cutlass.Float64", + torch.int8: "cutlass.Int8", + torch.int16: "cutlass.Int16", + torch.int32: "cutlass.Int32", + torch.int64: "cutlass.Int64", + torch.bool: "cutlass.Boolean", + torch.float8_e4m3fn: "cutlass.Float8E4M3FN", + torch.float8_e5m2: "cutlass.Float8E5M2", + } + + # Math constants + LOG2_E = 1.4426950408889634 # 1/ln(2) for converting natural exp to base-2 exp + + @staticmethod + def _ensure_tensor_ssa(arg: CuteDSLArg, template_tensor: CuteDSLArg) -> str: + """ + Convert scalar arguments to TensorSSA using cute.full_like if needed. + + Args: + arg: The argument to check (CSEVariable for tensors, str for scalars, or OpsValue wrapper) + template_tensor: A tensor argument to use as template for full_like + + Returns: + String representation suitable for CuteDSL operations + """ + if isinstance(arg, CSEVariable): + return str(arg) + + if isinstance(arg, OpsValue) and isinstance(arg.value, CSEVariable): + return str(arg.value) + + if isinstance(template_tensor, CSEVariable): + return f"cute.full_like({template_tensor}, {arg})" + + return str(arg) + + @staticmethod + def _extract_dtype_and_bounds( + *args: CuteDSLArg, + ) -> tuple[Optional[torch.dtype], ValueRanges[sympy.Expr]]: + """Extract dtype and bounds from CSEVariable arguments.""" + for arg in args: + if isinstance(arg, CSEVariable): + return arg.dtype, arg.bounds + return None, ValueRanges.unknown() + + @staticmethod + def _apply_binary_op(a: CuteDSLArg, b: CuteDSLArg, op_format: str) -> CuteDSLArg: + """ + Apply a binary operation with automatic scalar-to-tensor conversion. + + CuteDSL requires both operands to be TensorSSA objects for tensor operations. + This helper automatically converts scalar arguments to TensorSSA using + cute.full_like when at least one argument is a tensor (CSEVariable). + + Args: + a: First operand (CSEVariable for tensors, str for scalars) + b: Second operand (CSEVariable for tensors, str for scalars) + op_format: Format string with {a} and {b} placeholders for the operation + + Returns: + CSEVariable if at least one operand is a CSEVariable, otherwise string + """ + tensor_arg = ( + a + if isinstance(a, CSEVariable) + else b + if isinstance(b, CSEVariable) + else None + ) + if tensor_arg is not None: + a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg) + b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg) + result_expr = op_format.format(a=a_ssa, b=b_ssa) + + dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(a, b) + + # Create and return CSEVariable using CSE generation for caching + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=bounds, dtype=dtype + ) + + return op_format.format(a=a, b=b) + + @staticmethod + def _apply_unary_op(x: CuteDSLArg, op_format: str) -> CuteDSLArg: + """ + Apply a unary operation, returning CSEVariable if input is CSEVariable. + + Args: + x: Input operand (CSEVariable for tensors, str for scalars) + op_format: Format string with {x} placeholder for the operation + + Returns: + CSEVariable if input is a CSEVariable, otherwise string + """ + if isinstance(x, CSEVariable): + result_expr = op_format.format(x=str(x)) + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=x.bounds, dtype=x.dtype + ) + + return op_format.format(x=x) + + @staticmethod + def constant(value: Union[bool, float, int], dtype: torch.dtype) -> str: + """Generate CuteDSL constant representation.""" + if value == float("-inf"): + return "float('-inf')" + elif value == float("inf"): + return "float('inf')" + elif math.isnan(value): + return "float('nan')" + return repr(value) + + @staticmethod + def add(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} + {b})") + + @staticmethod + def mul(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} * {b})") + + @staticmethod + def sub(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} - {b})") + + @staticmethod + def truediv(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} / {b})") + + @staticmethod + def mod(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})") + + @staticmethod + def remainder(a, b): + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})") + + @staticmethod + def exp(x: CuteDSLArg) -> CuteDSLArg: + """Exponential using CuteDSL cute.math.exp function.""" + return CuteDSLOpOverrides._apply_unary_op( + x, f"cute.math.exp2({{x}} * {CuteDSLOpOverrides.LOG2_E})" + ) + + @staticmethod + def sqrt(x: CuteDSLArg) -> CuteDSLArg: + """Square root using CuteDSL cute.math.sqrt function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sqrt({x})") + + @staticmethod + def log(x: CuteDSLArg) -> CuteDSLArg: + """Natural logarithm using CuteDSL cute.math.log function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.log({x})") + + @staticmethod + def cos(x: CuteDSLArg) -> CuteDSLArg: + """Cosine using CuteDSL cute.math.cos function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.cos({x})") + + @staticmethod + def sin(x: CuteDSLArg) -> CuteDSLArg: + """Sine using CuteDSL cute.math.sin function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sin({x})") + + @staticmethod + def erf(x: CuteDSLArg) -> CuteDSLArg: + """Error function using CuteDSL cute.math.erf function.""" + return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.erf({x})") + + @staticmethod + def maximum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + raise NotImplementedError("TODO: maximum is not supported yet for TensorSSA") + + @staticmethod + def minimum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + raise NotImplementedError("TODO: minimum is not supported yet for TensorSSA") + + @staticmethod + def where( + condition: CuteDSLArg, + a: CuteDSLArg, + b: CuteDSLArg, + ) -> CuteDSLArg: + """Conditional selection - handles both CSEVariable and string inputs.""" + # Find a tensor argument to use as template for full_like + # Priority: use 'a' if it's a tensor, else use 'b', else condition + tensor_arg = ( + a + if isinstance(a, CSEVariable) + else ( + b + if isinstance(b, CSEVariable) + else condition + if isinstance(condition, CSEVariable) + else None + ) + ) + + if tensor_arg is not None: + a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg) + b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg) + result_expr = f"cute.where({condition}, {a_ssa}, {b_ssa})" + + dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds( + a, b, condition + ) + + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=bounds, dtype=dtype + ) + + return f"cute.where({condition}, {a}, {b})" + + @staticmethod + def pow(a: CuteDSLArg, b: CuteDSLArg): + return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} ** {b})") + + @staticmethod + def abs(x: CuteDSLArg) -> CuteDSLArg: + """Absolute value using CuteDSL cute.math.abs function.""" + if isinstance(x, CSEVariable): + x_dtype = x.dtype + elif isinstance(x, OpsValue) and isinstance(x.value, CSEVariable): + x_dtype = x.value.dtype + else: + x_dtype = torch.float32 + + abs_op = ( + "mlir_math.absf" + if x_dtype in (torch.float16, torch.bfloat16, torch.float32) + else "mlir_math.absi" + ) + return CuteDSLOpOverrides._apply_unary_op( + # pyrefly: ignore [bad-argument-type] + x, + f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)", + ) + + @staticmethod + def neg(x: CuteDSLArg) -> CuteDSLArg: + """Negation using CuteDSL TensorSSA __neg__ operator.""" + # TODO: See https://github.com/NVIDIA/cutlass/issues/2584 + return CuteDSLOpOverrides._apply_unary_op( + x, "cute.TensorSSA(-{x}, {x}.shape, {x}.dtype)" + ) + + @staticmethod + def to_dtype( + x: CuteDSLArg, dtype: torch.dtype, src_dtype=None, use_compute_types=True + ) -> CuteDSLArg: + """Type conversion using CuteDSL TensorSSA.to(Type[Numeric]). + + Maps torch dtypes to cutlass.cute.typing numeric types and emits + `{x}.to(cute.typing.)`. + + Raises NotImplementedError for unsigned integer and unsupported dtypes. + """ + # Always convert up from bf16 and fp16 TODO on configuring + dtype = upcast_compute_type(dtype) + + cute_type = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get(dtype) + if cute_type is None: + raise NotImplementedError( + f"CuteDSL dtype cast not implemented for torch dtype: {dtype}" + ) + + if isinstance(x, CSEVariable): + result_expr = f"{str(x)}.to({cute_type})" + return V.kernel.cse.generate( + V.kernel.body, result_expr, bounds=x.bounds, dtype=dtype + ) + + return f"{x}.to({cute_type})" + + @staticmethod + def tanh(x0: CuteDSLArg) -> CuteDSLArg: + """Hyperbolic tangent using CuteDSL cute.math.tanh function.""" + return CuteDSLOpOverrides._apply_unary_op(x0, "cute.math.tanh({x})") + + # Logical operations + @staticmethod + def logical_and(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} and {b})") + + @staticmethod + def logical_or(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} or {b})") + + @staticmethod + def logical_not(a): + """Logical NOT.""" + return CuteDSLOpOverrides._apply_unary_op(a, "({x} == 0)") + + # Comparison operations + @staticmethod + def eq(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.eq({a}, {b})") + + @staticmethod + def ne(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ne({a}, {b})") + + @staticmethod + def lt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.lt({a}, {b})") + + @staticmethod + def le(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.le({a}, {b})") + + @staticmethod + def gt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.gt({a}, {b})") + + @staticmethod + def ge(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: + return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ge({a}, {b})") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc1089a4082acc02f4b039f2fda9c0a726648d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ... import config +from ...codecache import code_hash, get_path +from ...ir import CuteDSLTemplateBuffer +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, +) +from ...select_algorithm import PartialRender +from ...utils import get_fused_kernel_name, get_kernel_metadata +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class CuteDSLScheduling(BaseScheduling): + """ + Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels. + This class is intended to be used in combination with other schedulers, + and delegated to by CUDACombinedScheduling. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + @staticmethod + def is_cutedsl_template(node: BaseSchedulerNode) -> bool: + """Check if a node is a CuteDSL template.""" + return isinstance(node, SchedulerNode) and isinstance( + node.node, CuteDSLTemplateBuffer + ) + + def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool: + """Check if a node is a fused CuteDSL template.""" + return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + TODO CuteDSL doesn't support vertical fusion yet. + This could be extended in the future for epilogue fusion. + """ + return False + + def define_kernel(self, src_code_str: str, node_schedule) -> str: + """Produce the kernel string + Args: + src_code_str: The finalized kernel code string + node_schedule: List of nodes in the schedule + + Note: + This is a little weird since async_compile.cutedsl() has to write the string to + a file in order to cute compile it. Feels bad to have two... + """ + wrapper = V.graph.wrapper_code + + # Use the string as the key for caching + if src_code_str in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code_str] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"cutedsl_{kernel_hash}" + else: + kernel_name = f"cutedsl_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code_str] = kernel_name + src_code_str = src_code_str.replace( + str(Placeholder.KERNEL_NAME), kernel_name + ) + + _, _, kernel_path = get_path(code_hash(src_code_str), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''") + compile_wrapper.splice(src_code_str, strip=True) + compile_wrapper.writeline("''')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CuteDSL template. Currently doesn't support fusion. + """ + assert self.is_cutedsl_template(template_node), ( + "Template node passed to CuteDSLScheduling.codegen_template must be a " + "SchedulerNode that wraps a CuteDSLTemplateBuffer" + ) + # TODO remove when supported + assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet" + assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet" + + template_node = cast(SchedulerNode, template_node) + ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node) + + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] + template_node.mark_run() + src_code = render() + # Finalize PartialRender if needed + if isinstance(src_code, PartialRender): + src_code_str = src_code.finalize_all() + else: + src_code_str = src_code + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code_str, node_schedule) + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bf30480981378daca74cf4ab4b1e4c01e8065e79 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cutedsl/cutedsl_template.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from collections.abc import Iterable +from typing import Any, Optional, Union +from unittest.mock import patch + +from torch._inductor.ir import ShapeAsConstantBuffer +from torch._inductor.utils import Placeholder +from torch._inductor.virtualized import V +from torch._logging import getArtifactLogger + +from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta +from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, IRNode, Layout, TensorBox +from ..common import KernelTemplate +from .cutedsl_kernel import CuteDSLTemplateKernel + + +log = getArtifactLogger(__name__, "output_code") + + +class CuteDSLTemplate(KernelTemplate): + """Template for generating CuteDSL (CUTLASS Python DSL) kernels.""" + + kernel_type: type[Any] = CuteDSLTemplateKernel + index_counter = itertools.count() + all_templates: dict[str, "CuteDSLTemplate"] = {} + + def __init__( + self, + name: str, + source: str, + subgraph_fn: Optional[Any] = None, + mask_fn: Optional[Any] = None, + ) -> None: + super().__init__(name) + self.source = source + self.subgraph_fn = subgraph_fn + self.mask_fn = mask_fn + self.template = CuteDSLTemplate._template_from_string(source) + assert name not in self.all_templates, f"duplicate template name, {name}" + CuteDSLTemplate.all_templates[name] = self + + @staticmethod + @functools.lru_cache(None) + # pyrefly: ignore [bad-override] + def _template_from_string(source: str) -> Any: + return KernelTemplate._template_from_string(source) + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + """ + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.debug("CuteDSL template choice generation failed: %s", e) # noqa: G200 + return e + except Exception as e: + log.debug("CuteDSL template choice generation error: %s", e) # noqa: G200 + return NotImplementedError(f"CuteDSL template failed: {e}") + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """Generate the CuteDSL kernel caller.""" + input_nodes = kwargs.pop("input_nodes") + layout = kwargs.pop("layout") + mutated_inputs = kwargs.pop("mutated_inputs", None) + subgraphs = kwargs.pop("subgraphs", None) + + kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}" + + if self.template is None: + raise RuntimeError("Template compilation failed (Jinja2 required)") + + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Patch V.graph.get_dtype to handle the fake buf_out buffer + with patch.object( + V.graph, "get_dtype", KernelTemplate._fake_get_dtype(self.output_node) + ): + kernel = self.kernel_type( + kernel_name=kernel_name, + input_nodes=input_nodes, + output_node=self.output_node, + subgraphs=subgraphs, + ) + code = kernel.render(self.template, **kwargs) + + log.debug("Generated CuteDSL Code:\n%s", code) + + bmreq = CuteDSLBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=tuple(), + source_code=code, + ) + + def make_kernel_render(out_node, hint_override: Optional[int] = None): + """ + Factory function that creates a kernel renderer for the final output. + + This closure captures the current template and parameters, but allows + the output node to be specified later. This is used during the final + kernel selection phase when the actual output buffer is available. + """ + render_kernel = self.kernel_type( + kernel_name=str(Placeholder.KERNEL_NAME), + input_nodes=input_nodes, + output_node=out_node, + subgraphs=subgraphs, + ) + + def render(): + return render_kernel.render(self.template, **kwargs) + + return render_kernel, render + + return CuteDSLTemplateCaller( + name=kernel_name, + input_nodes=input_nodes, + layout=layout, + make_kernel_render=make_kernel_render, + bmreq=bmreq, + template=self, + mutated_inputs=mutated_inputs, + ) + + +class CuteDSLTemplateCaller(ChoiceCaller): + """Caller for CuteDSL templates that integrates with the autotuning system.""" + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Any, + bmreq: CuteDSLBenchmarkRequest, + template: "CuteDSLTemplate", + mutated_inputs: Optional[Iterable[IRNode]] = None, + ): + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + description=f"CuteDSL template {name}", + ) + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.mutated_inputs = mutated_inputs + + def __str__(self) -> str: + return f"CuteDSLTemplateCaller({self.name})" + + def benchmark(self, *args, out) -> float: + """Benchmark the kernel execution.""" + return self.bmreq.benchmark(*args, out=out) + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + """Create the output node for this template choice.""" + return TensorBox.create( + CuteDSLTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + mutated_inputs=self.mutated_inputs, + ) + ) + + def call_name(self) -> str: + """Return the kernel call name.""" + return self.name + + def to_callable(self) -> Any: + """Return callable that can execute this kernel.""" + return self.make_kernel_render + + def hash_key(self) -> str: + """Return unique hash key for this choice.""" + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def info_dict(self) -> dict[str, Any]: + """Return information about this kernel.""" + return { + "name": self.name, + "backend": "CuteDSL", + "template": self.template.name, + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..135bee2b8fe9226d5b69077201c0b08bfc8460a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mtia/device_op_overrides.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class MTIADeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _mtia_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.mtia.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.mtia.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.mtia.device({device_idx})" + + +register_device_op_overrides("mtia", MTIADeviceOpOverrides()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py new file mode 100644 index 0000000000000000000000000000000000000000..277b6ed3749486074a583d61f6f2909886eb60c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -0,0 +1,627 @@ +# mypy: allow-untyped-defs +import copy +import logging +import random +from typing import Any +from typing_extensions import override + +from torch._inductor.virtualized import V + +from .rocm_template import ArgInfo + + +try: + import ck4inductor # type: ignore[import] +except ImportError: + ck4inductor = None + +if ck4inductor is not None: + from ck4inductor.grouped_conv_fwd.gen_instances import ( # type: ignore[import] + gen_conv_ops_library, + ) + from ck4inductor.grouped_conv_fwd.op import ( # type: ignore[import] # noqa: TCH002 + CKGroupedConvFwdOp, + ) +else: + + def gen_conv_ops_library(): + return [] + + +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def torch_layout_to_ck_layouts(torch_layout): + # logically, torch tensors are always NCHW, + # and channels-last memory layout is visible in the strides + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + # when input or output is NCHW + # NB: torch.conv2d result is always NCHW + return ["NGCHW", "GKCYX", "NGKHW"] + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + # when input or output or weight is channels-last + return ["NHWGC", "GKYXC", "NHWGK"] + else: + return None + + +def torch_layout_to_ck_input_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGCHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGC" + else: + return None + + +def torch_layout_to_ck_weight_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "GKCYX" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "GKYXC" + else: + return None + + +def torch_layout_to_ck_output_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGKHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGK" + else: + return None + + +class CKGroupedConvFwdTemplate(CKTemplate): + conv_template = r""" + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto conv = {{instance_type}} {}; + auto invoker = conv.MakeInvoker(); + + using ck::index_t; + + constexpr index_t NumDTensor = {{n_d_tensors}}; + constexpr index_t NDimSpatial = {{n_dim_spatial}}; + const std::vector FilterSize = { FilterSize_0, FilterSize_1 }; + const std::vector InputSize = { InputSize_0, InputSize_1 }; + const std::vector ConvolutionStrides = { ConvolutionStrides_0, ConvolutionStrides_1 }; + const std::vector Dilations = { Dilations_0, Dilations_1 }; + const std::vector LeftPads = { LeftPads_0, LeftPads_1 }; + const std::vector RightPads = { RightPads_0, RightPads_1 }; + + + auto conv_param = ck::utils::conv::ConvParam { + NDimSpatial, + GroupCount, + NBatch, + NOutChannels, + NInChannels, + FilterSize, + InputSize, + ConvolutionStrides, + Dilations, + LeftPads, + RightPads, + }; + + using InLayout = ck::tensor_layout::convolution::{{input_layout}}; + using WeiLayout = ck::tensor_layout::convolution::{{weight_layout}}; + using OutLayout = ck::tensor_layout::convolution::{{output_layout}}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const void* p_a = input; + const void* p_b = weight; + const std::array p_ds; + void* p_e = output; + std::array a_g_n_c_wis_lengths; + std::array a_g_n_c_wis_strides; + std::array b_g_k_c_xs_lengths; + std::array b_g_k_c_xs_strides; + std::array, NumDTensor> ds_g_n_k_wos_lengths; + std::array, NumDTensor> ds_g_n_k_wos_strides; + std::array e_g_n_k_wos_lengths; + std::array e_g_n_k_wos_strides; + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + const auto a_element_op = PassThrough {}; + const auto b_element_op = PassThrough {}; + const auto cde_element_op = PassThrough {}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto argument = conv.MakeArgument( + p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op + ); + if (!conv.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for conv instance " << conv.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = conv.GetWorkSpaceSize(&argument); + return 0; + } + + if (p_a == nullptr) { + std::cerr << "p_a is nullptr" << std::endl; + return -1; + } + if (p_b == nullptr) { + std::cerr << "p_b is nullptr" << std::endl; + return -1; + } + if (p_e == nullptr) { + std::cerr << "p_e is nullptr" << std::endl; + return -1; + } + + // when debugging, do time kernel to serialize launches + auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + + if (workspace != nullptr) { + conv.SetWorkSpacePointer(&argument, workspace, stream_config); + } + + // run the kernel + float elapsed_time = invoker.Run(argument, stream_config); + return 0; + } // kernel definition + } // extern C + + #ifdef GENERATE_CK_STANDALONE_RUNNER + int main(int argc, char** argv) { + (void) argc; + (void) argv; + return 0; + } + #endif // GENERATE_CK_STANDALONE_RUNNER +""" + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK conv globals + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GKCX = ck::tensor_layout::convolution::GKCX; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + using NGKW = ck::tensor_layout::convolution::NGKW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; + + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using KXGC = ck::tensor_layout::convolution::KXGC; + using KYXGC = ck::tensor_layout::convolution::KYXGC; + using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using NGCW = ck::tensor_layout::convolution::NGCW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; + + using G_K = ck::tensor_layout::convolution::G_K; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + using ConvolutionForwardSpecialization = ck::tensor_operation::device::ConvolutionForwardSpecialization; + + using OutElementOp = PassThrough; + + namespace ck { + namespace utils { + namespace conv { + + ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + } // namespace conv + } // namespace utils + } // namespace ck + + const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + void HostTensorDescriptor::CalculateStrides() { + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); + } + """ + ) + return res + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK conv headers + + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + + #include "ck/library/utility/convolution_parameter.hpp" + #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + """ + ) + return res + + @staticmethod + def add_ck_conv_choices( + choices, + layout, + input_nodes, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + template = CKGroupedConvFwdTemplate( + input_nodes, + layout, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=n_spatial_dimensions, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def __init__( + self, + input_nodes, + layout, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + super().__init__( + "ck_conv_template", + input_nodes, + layout, + ) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.n_spatial_dimensions = n_spatial_dimensions + + def filter_op(self, op: "CKGroupedConvFwdOp"): # type: ignore[name-defined] + metas = [ + T.get_layout() + for T in [*self.input_nodes, self.output_node] + if T is not None + ] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.e_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_input_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_weight_layout(W_meta): + return None + if op.e_layout != torch_layout_to_ck_output_layout(Y_meta): + return None + # disable the instance if number of spatial dimensions doesn't match + if op.n_dim_spatial != self.n_spatial_dimensions: + return None + # disable 1x1 and odd-channels conv specializations for now + if "Default" not in op.conv_forward_specialization: + return None + return op + + def gen_ops(self): + unfiltered_instances = gen_conv_ops_library() + + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + def emit_ck_instance(self, op: "CKGroupedConvFwdOp") -> tuple[str, str]: # type: ignore[name-defined] + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + # pyrefly: ignore [bad-argument-type] + template_params.append(arg) + else: + if field_value is not None: + # pyrefly: ignore [bad-argument-type] + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), self._template_from_string(template_type).render(operation_name=op.name()) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGroupedConvFwdOp", # type: ignore[name-defined] + **kwargs, + ) -> str: + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + instance_definition, instance_type = self.emit_ck_instance(op) + + size_arg_strs = [ + "GroupCount", + "NBatch", + "NOutChannels", + "NInChannels", + "FilterSize_0", + "FilterSize_1", + "InputSize_0", + "InputSize_1", + "ConvolutionStrides_0", + "ConvolutionStrides_1", + "Dilations_0", + "Dilations_1", + "LeftPads_0", + "LeftPads_1", + "RightPads_0", + "RightPads_1", + ] + + return self._template_from_string(self.conv_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + instance_type=instance_type, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias] if Bias is not None else [X, W], + outputs=[Y], + names_str="input, weight, bias, output" + if Bias is not None + else "input, weight, output", + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + n_d_tensors=1 if Bias is not None else 0, + n_dim_spatial=self.n_spatial_dimensions, + input_layout=op.a_layout, + weight_layout=op.b_layout, + output_layout=op.e_layout, + ) + + def size_args(self): + x, w = self.input_nodes[0], self.input_nodes[1] + y = self.output_node + + group_count = self.groups + n_batch = x.shape[0] # type: ignore[index] + n_out_channels = y.shape[1] # type: ignore[index] + n_in_channels = x.shape[1] # type: ignore[index] + + filter_size_0, filter_size_1 = w.shape[2:4] # type: ignore[index] + input_size_0, input_size_1 = x.shape[2:4] # type: ignore[index] + convolution_strides_0, convolution_strides_1 = self.stride + dilations_0, dilations_1 = self.dilation + left_pads_0, left_pads_1 = self.padding + right_pads_0, right_pads_1 = self.padding + + return ( + group_count, + n_batch, + n_out_channels, + n_in_channels, + filter_size_0, + filter_size_1, + input_size_0, + input_size_1, + convolution_strides_0, + convolution_strides_1, + dilations_0, + dilations_1, + left_pads_0, + left_pads_1, + right_pads_0, + right_pads_1, + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b1eaf5c228eed80b5b9e40e3bbbd4e2de07b7c45 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py @@ -0,0 +1,110 @@ +from typing import Any +from typing_extensions import override + +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + +from .rocm_template import ArgInfo + + +class CKTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", # gfx94 + torch.float8_e4m3fn: "F8", # gfx95 + torch.float8_e5m2fnuz: "BF8", # gfx94 + torch.float8_e5m2: "BF8", # gfx95 + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + + #ifdef DEBUG_LOG + #define DEBUG_LOG_TMP DEBUG_LOG + #undef DEBUG_LOG + #else + #define DEBUG_LOG_TMP 0 + #endif + #include "ck/ck.hpp" + #undef DEBUG_LOG + #define DEBUG_LOG DEBUG_LOG_TMP + + #include "ck/utility/data_type.hpp" + #include "ck/library/utility/check_err.hpp" + #include "ck/library/utility/device_memory.hpp" + #include "ck/library/utility/fill.hpp" + #include "ck/library/utility/host_tensor.hpp" + #include "ck/library/utility/host_tensor_generator.hpp" + #include "ck/library/utility/literals.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK globals + + template + using S = ck::Sequence; + + template + using Tuple = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + using Scale = ck::tensor_operation::element_wise::Scale; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + + // see "composable_kernel/include/ck/utility/data_type.hpp" + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using F16 = ck::half_t; + using F32 = float; + // using F64 = double; + using BF16 = ck::bhalf_t; + // using I32 = int32_t; + // using I8 = int8_t; + // using I4 = ck::int4_t; + + #if DEBUG_LOG + static constexpr auto kDEBUG_LOG = 1; + #else + static constexpr auto kDEBUG_LOG = 0; + #endif + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py new file mode 100644 index 0000000000000000000000000000000000000000..70d31d635cc36dca295b1d82066376a1185c4da9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py @@ -0,0 +1,58 @@ +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + + +class CKTileTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", # gfx94 + torch.float8_e4m3fn: "F8", # gfx95 + torch.float8_e5m2fnuz: "BF8", # gfx94 + torch.float8_e5m2: "BF8", # gfx95 + } + + ck_dtype_to_size = { + "FP16": 2, + "BF16": 2, + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + #include "ck_tile/core.hpp" + + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using F8 = ck_tile::fp8_t; + using BF8 = ck_tile::bf8_t; + using F16 = ck_tile::half_t; + using F32 = float; + using BF16 = ck_tile::bfloat16_t; + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..94a79297ef5e47e16f98a8968c815262a9d24d75 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -0,0 +1,979 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import functools +import logging +import random +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_tile_template import CKTileTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.codegen.rocm.rocm_template import ArgInfo +from torch._inductor.ir import Buffer, Layout +from torch.utils._ordered_set import OrderedSet + +from ...utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def is_static_int(number): + import sympy + + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +class CKTileGemmTemplate(CKTileTemplate): + """ + This class is used for rendering CK-Tile Universal GEMM kernels + """ + + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + + using {{instance_namespace}}::BaseGemmPipeline; + using {{instance_namespace}}::TilePartitioner; + + constexpr auto TileK = {{instance_namespace}}::TileK; + constexpr auto kPrefetchStages = BaseGemmPipeline::PrefetchStages; + + const auto BiasTerms = std::array (); + const auto BiasStrides = std::array (); + + auto kargs = ck_tile::UniversalGemmKernelArgs<> { + {X}, + {W}, + BiasTerms, + Y, + M, + N, + K, + {LDA}, + {LDB}, + BiasStrides, + LDC, + kBatch + }; + + if (workspace_size) { + *workspace_size = 0; + return 0; + } + + // run the kernel + const auto dispatch = [&](const auto has_hot_loop_, const auto tail_number_) constexpr { + using Kernel = {{instance_namespace}}::Kernel; + + if (!Kernel::IsSupportedArgument(kargs)) { + // we do our best to statically avoid this case in `filter_op` + throw std::runtime_error("invalid argument"); + } + auto stream_config = ck_tile::stream_config{stream}; + auto grid_size = Kernel::GridSize(M, N, kBatch); + constexpr auto block_size = Kernel::BlockSize(); + constexpr auto lds_bytes = 0; + constexpr auto kBlockPerCU = 1; + auto gemm = ck_tile::make_kernel(Kernel{}, grid_size, block_size, lds_bytes, kargs); + float elapsed_time = ck_tile::launch_kernel(stream_config, gemm); + }; + + const ck_tile::index_t k_grain = kBatch * TileK; + const ck_tile::index_t K_split = (K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + {{rendered_dispatch}} + + return 0; + } // kernel definition + } // extern C + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + "ck_tile_gemm_template", + input_nodes=input_nodes, + layout=layout, + ) + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK GEMM header(s) + + #include "ck_tile/ops/gemm.hpp" + #include "ck_tile/ops/epilogue.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + template + void dispatch_memory_pipeline_hot_loop(const ck_tile::TailNumber tail_num, Dispatcher dispatch) + { + if(tail_num == ck_tile::TailNumber::One) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + """ + ) + return res + + def check_dtypes(self, op: "CKTileGemmOperation"): + X_dtype, W_dtype, out_dtype = [ + T.get_layout().dtype for T in [*self.input_nodes, self.output_node] + ] + if op.datatype_a != self._TORCH_DTYPE_TO_CK[X_dtype]: + return False + if op.datatype_b != self._TORCH_DTYPE_TO_CK[W_dtype]: + return False + if op.datatype_c != self._TORCH_DTYPE_TO_CK[out_dtype]: + return False + return True + + def check_layouts(self, op: "CKTileGemmOperation"): + X_layout, W_layout, out_layout = [ + torch_layout_to_ck_layout(T.get_layout()) + for T in [*self.input_nodes, self.output_node] + ] + if op.layout_a != X_layout: + return False + if op.layout_b != W_layout: + return False + if op.layout_c != out_layout: + return False + return True + + def get_gemm_problem_size(self): + X_size, W_size = [T.get_layout().size for T in [*self.input_nodes]] + + M, K = X_size + _, N = W_size + + return M, N, K + + def check_block_tiles(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the block tile size + This helper function enforces it for the inputs and the output. + """ + M, N, K = self.get_gemm_problem_size() + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + if op.layout_a == "Row": + # handle in kBatch check + return True + elif op.layout_a == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_b == "Col": + # handle in kBatch check + return True + else: + raise AssertionError(f"Invalid {op.layout_b=}") + + if op.layout_c == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_c == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_c=}") + + return True + + def check_alignments(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the vector load size. + """ + M, N, K = self.get_gemm_problem_size() + + def max_alignment(contiguous_elements_per_tile, elements_per_thread, ck_dtype): + for vector_load_bytes in (16, 8, 4, 2, 1): + alignment = vector_load_bytes // self.ck_dtype_to_size[ck_dtype] + if ( + alignment > 0 + and contiguous_elements_per_tile % alignment == 0 + and elements_per_thread % alignment == 0 + ): + return alignment + + threads_per_block = ( + op.warp_m * op.warp_n * op.warp_k * self.gfx9_threads_per_warp + ) + a_elements_per_thread = op.tile_m * op.tile_k / threads_per_block + b_elements_per_thread = op.tile_n * op.tile_k / threads_per_block + + if op.layout_a == "Row": + # K is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_k, a_elements_per_thread, op.datatype_a + ) + if is_static_int(K) and K % a_max_vector_size != 0: + return False + elif op.layout_a == "Col": + # M is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_m, a_elements_per_thread, op.datatype_a + ) + if is_static_int(M) and M % a_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + # N is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_n, b_elements_per_thread, op.datatype_b + ) + if is_static_int(N) and N % b_max_vector_size != 0: + return False + elif op.layout_b == "Col": + # K is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_k, b_elements_per_thread, op.datatype_b + ) + if is_static_int(K) and K % b_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_b=}") + + # the `default` epilogue writes C to memory by 1 tensor element + # (divisibility check not necessary) + # the `cshuffle` epilogue writes C to memory by 16 bytes + # (so the contiguous C dimension size must be divisible by the number of tensor elements in 16 bytes) + if op.epilogue == "CShuffle": + if ( + op.layout_c == "Row" + and is_static_int(N) + and N % (16 / self.ck_dtype_to_size[op.datatype_c]) != 0 + ): + return False + + return True + + def check_warp_tiles(self, op: "CKTileGemmOperation"): + if op.tile_m % (op.warp_m * op.warp_tile_m) != 0: + return False + if op.tile_n % (op.warp_n * op.warp_tile_n) != 0: + return False + if op.tile_k % (op.warp_k * op.warp_tile_k) != 0: + return False + return True + + def check_block_tile_size(self, op: "CKTileGemmOperation"): + # assuming LDS size is 64KB + if op.pipeline == "CompV4": + max_block_tile_size = 2**15 + else: + max_block_tile_size = 2**16 + + block_tile_size = ( + self.ck_dtype_to_size[op.datatype_a] * op.tile_m * op.tile_k + + self.ck_dtype_to_size[op.datatype_b] * op.tile_n * op.tile_k + ) + if block_tile_size > max_block_tile_size: + return False + return True + + def filter_op(self, op: "CKTileGemmOperation"): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + if not self.check_dtypes(op): + return None + if not self.check_layouts(op): + return None + if not self.check_block_tiles(op): + return None + if not self.check_alignments(op): + return None + + return op + + def emit_ck_instance(self, op: "CKTileGemmOperation"): + """ + This method is used to generate code which defines the type alias for the generated kernel class + """ + template_definition = r""" + // Gemm operator {{operation_name}} + + namespace {{operation_name}} { + // block tile + constexpr int32_t TileM = {{tile_m}}; + constexpr int32_t TileN = {{tile_n}}; + constexpr int32_t TileK = {{tile_k}}; + // warps per block + constexpr int32_t WarpM = {{warp_m}}; + constexpr int32_t WarpN = {{warp_n}}; + constexpr int32_t WarpK = {{warp_k}}; + // xdl tile + constexpr int32_t WarpTileM = {{warp_tile_m}}; + constexpr int32_t WarpTileN = {{warp_tile_n}}; + constexpr int32_t WarpTileK = {{warp_tile_k}}; + + constexpr bool kPadM = {{m_is_padded}}; + constexpr bool kPadN = {{n_is_padded}}; + constexpr bool kPadK = {{k_is_padded}}; + + using ALayout = {{layout_a}}; + using BLayout = {{layout_b}}; + using CLayout = {{layout_c}}; + + using ADataType = {{datatype_a}}; + using BDataType = {{datatype_b}}; + using CDataType = {{datatype_c}}; + using AccDataType = F32; + + constexpr bool permuteA = false; + constexpr bool permuteB = false; + constexpr bool DoubleSmemBuffer = {{has_double_smem_buffer}}; + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + constexpr ck_tile::index_t TilePartitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + {{rendered_scheduler}} + + template + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + {{rendered_pipeline}} + + {{rendered_epilogue}} + + template + using Kernel = ck_tile::GemmKernel, GemmEpilogue>; + } + +""" + + def render_epilogue(epilogue_type): + if epilogue_type == "Default": + return r""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem; + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; + """ + elif epilogue_type == "CShuffle": + return r""" + constexpr auto kMemoryOperation = ck_tile::memory_operation_enum::set; + using DsDataType = ck_tile::tuple<>; // no bias terms for vanilla GEMM + using DsLayout = ck_tile::tuple<>; + constexpr auto ELayout = CLayout; + using CDEElementWise = ck_tile::element_wise::PassThrough; // no-op + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + """ + else: + raise AssertionError("Epilogue must be set") + + def render_pipeline(pipeline_type): + return rf""" + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCr{pipeline_type}; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCr{pipeline_type}>; + """ + + def render_scheduler(scheduler_type): + return rf""" + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::{scheduler_type}; + """ + + rendered_definition = self._template_from_string(template_definition).render( + operation_name=op.name(), + **asdict(op), + rendered_scheduler=render_scheduler(op.scheduler), + rendered_pipeline=render_pipeline(op.pipeline), + rendered_epilogue=render_epilogue(op.epilogue), + has_double_smem_buffer=("true" if op.pipeline == "CompV4" else "false"), + ) + return rendered_definition + + def render( # type: ignore[override] + self, kernel: ROCmTemplateKernel, op: "CKTileGemmOperation", **kwargs + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes") + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + assert 2 == len(self.input_nodes) + X, W = self.input_nodes + Y = self.output_node + + instance_definition = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + + def render_dispatch(pipeline_type, op_name): + switch_tailnum_template = r""" + switch (tail_num) { + {% for tail_num in valid_tailnums %} + case ck_tile::TailNumber::{{tail_num}}: + dispatch({{has_hot_loop}}, + ck_tile::integral_constant{}); + break; + {% endfor %} + default: + std::ostringstream err; + err << "Unsupported dispatch: " + << "Pipeline: " << "{{pipeline}}" + << "Prefetch stages: " << kPrefetchStages + << "Tail num: " << tail_num; + throw std::runtime_error(err.str()); + } // switch tail_num + """ + dispatch_template = r""" + if (has_hot_loop) { + {{rendered_with_hot_loop}} + } + else { // has_hot_loop == false + {{rendered_without_hot_loop}} + } // if has_hot_loop + """ + if pipeline_type == "CompV3": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "Mem": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop="dispatch_memory_pipeline_hot_loop(tail_num, dispatch);", + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "CompV4": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Two", "Three"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + else: + raise AssertionError(f"Pipeline {pipeline_type} is not supported") + + return self._template_from_string(self.gemm_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, Y", + size_args=[ + f"int32_t {arg}" for arg in ["M", "N", "K", "LDA", "LDB", "LDC"] + ], + ), + instance_namespace=op.name(), + version_comment=version_comment, + rendered_dispatch=render_dispatch(op.pipeline, op.name()), + ) + + def gen_ops(self): + """ + Creates a list of `CKTileGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + instances = ops() + if not instances: + raise AssertionError( + "No Composable Kernel Universal GEMM instances found. " + "Please check if the library is installed." + ) + filtered_instances = list(filter(self.filter_op, instances)) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_tile_max_profiling_configs), + ) + if config.rocm.ck_tile_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after sample: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKTileGemmTemplate( + input_nodes, + layout, + ) + ops = template.gen_ops() + for op in ops: + for k_batch in template.k_batch_choices(op): + template.maybe_append_choice( + choices, + op=op, + kBatch=k_batch, + ) + + def k_batch_choices(self, op: "CKTileGemmOperation") -> tuple[int, ...]: + """ + Returns a list of k_batch choices for the template. + """ + default_choices = (1, 2, 4, 8, 16, 32) + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + _, _, K, _, _, _ = self.size_args() + if op.layout_a == "Row" or op.layout_b == "Col": + choices = tuple( + filter( + lambda k_batch: check(K, op.tile_k * k_batch, op.k_is_padded), + default_choices, + ) + ) + else: + choices = default_choices + + if op.epilogue == "Default": + choices = (1,) + + return choices + + def size_args(self): + """ + Sizes and strides to be used for the kernel call + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + Y = self.output_node + + M = X.get_size()[0] + K = X.get_size()[1] + N = W.get_size()[1] + LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1] + LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1] + LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] + + return M, N, K, LDA, LDB, LDC + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + # maybe_append_choice kwarg for k_batch must match the name of the argument + arg_names = OrderedSet([arg.name for arg in self.get_runtime_arg_info()]) + if not arg_names.issubset(kwargs): + raise ValueError( + "Missing runtime arguments: " + ", ".join(arg_names - kwargs.keys()) + ) + return [kwargs[k] for k in arg_names] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f8ff54f9f45bc46fb3d4be5b74d36990fc69cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -0,0 +1,1019 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import copy +import logging +import math +import random +from collections import namedtuple +from typing import Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.compile_command import rocm_compile_command +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.ir import Buffer, Layout +from torch._inductor.runtime.runtime_utils import next_power_of_2 + +from ...utils import IndentedBuffer, is_dynamic, try_import_ck_lib + + +_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib() + + +log = logging.getLogger(__name__) + +# lightweight collection of information about a single op +InductorROCmOp = namedtuple("InductorROCmOp", ["op", "kBatch"]) + +padding_lookup = { + "M": { + "GemmSpecialization::MPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "N": { + "GemmSpecialization::NPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "K": { + "GemmSpecialization::KPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, +} + + +def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +class CKGemmTemplate(CKTemplate): + # the JINJA template for rendering CK Universal GEMMs + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto gemm = {{instance_type}} {}; + auto invoker = gemm.MakeInvoker(); + {% if is_batched %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + B, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + M * K, // batch_stride_A + N * K, // batch_stride_B + std::array{ {{ds_batch_strides}} }, + M * N, // batch_stride_C + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% else %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + kBatch, // kBatch + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% endif %} + if (!gemm.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = gemm.GetWorkSpaceSize(&argument); + return 0; + } + // run the kernel + #ifdef GENERATE_CK_STANDALONE_RUNNER + const auto stream_config = StreamConfig{ + stream, + /* time kernel */ 1, + /* log level */ 1, + /* n_cold_iter */ 100, + /* n_hot_iter */ 100, + /* flush_l2_cache */ 1, + /* rotate_count */ 5}; + #else + const auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + #endif + + const float elapsed_time = invoker.Run(argument, stream_config); + + #ifdef GENERATE_CK_STANDALONE_RUNNER + std::cout << "elapsed time: " << elapsed_time << " ms" << std::endl; + #else + (void)elapsed_time; + #endif + return 0; + } // kernel definition + } // extern C + """ + + standalone_runner_template = r""" + #ifdef GENERATE_CK_STANDALONE_RUNNER + // standalone runner for the generated CK GEMM kernel + + {{inline_utils}} + + extern "C" { + int run_main(int argc, char** argv) { + {% if is_batched %} + const int32_t B = {{B}}; + {% endif %} + const int32_t M = {{M}}; + const int32_t N = {{N}}; + const int32_t K = {{K}}; + const int32_t LDA = {{LDA}}; + const int32_t LDB = {{LDB}}; + const int32_t LDC = {{LDC}}; + const int32_t LDD = {{LDD}}; + const int32_t kBatch = {{kBatch}}; + + using AElementType = {{a_ck_dtype}}; + using BElementType = {{b_ck_dtype}}; + using CElementType = {{c_ck_dtype}}; + {% if has_bias %} + using BiasElementType = {{bias_ck_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAElementType = {{scale_a_ck_dtype}}; + using ScaleBElementType = {{scale_b_ck_dtype}}; + {% endif %} + + using AArgType = {{a_torch_dtype}}; + using BArgType = {{b_torch_dtype}}; + using CArgType = {{c_torch_dtype}}; + {% if has_bias %} + using BiasArgType = {{bias_torch_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAArgType = {{scale_a_torch_dtype}}; + using ScaleBArgType = {{scale_b_torch_dtype}}; + {% endif %} + + using ALayout = {{a_layout}}; + using BLayout = {{b_layout}}; + using CLayout = {{c_layout}}; + {% if has_bias %} + using BiasLayout = {{bias_layout}}; + {% endif %} + + {% if is_batched %} + using strides_t = std::array; + auto get_strides = [](int32_t batch_stride, int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {batch_stride, leading_dimension, 1}; + } + return {batch_stride, 1, leading_dimension}; + }; + auto a_size = strides_t{B, M, K}; + auto a_stride = get_strides(M * K, LDA, ALayout{}); + auto b_size = strides_t{B, N, K}; + auto b_stride = get_strides(N * K, LDB, BLayout{}); + auto c_size = strides_t{B, M, N}; + auto c_stride = get_strides(M * N, LDC, CLayout{}); + {% else %} + using strides_t = std::array; + auto get_strides = [](int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {leading_dimension, 1}; + } + return {1, leading_dimension}; + }; + auto a_size = strides_t{M, K}; + auto a_stride = get_strides(LDA, ALayout{}); + auto b_size = strides_t{N, K}; + auto b_stride = get_strides(LDB, BLayout{}); + auto c_size = strides_t{M, N}; + auto c_stride = get_strides(LDC, CLayout{}); + {% endif %} + + Tensor a_m_k ( HostTensorDescriptor ( a_size, a_stride ) ); + Tensor b_k_n ( HostTensorDescriptor ( b_size, b_stride ) ); + {% if has_bias %} + Tensor d_m_n ( HostTensorDescriptor ( c_size, get_strides(LDD, BiasLayout{}) ) ); + {% endif %} + {% if has_scale %} + // NB: these are hardcoded + Tensor s_a_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Row{}) )); + Tensor s_b_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Col{}) )); + {% endif %} + + Tensor c_m_n_host ( HostTensorDescriptor ( c_size, c_stride ) ); + Tensor c_m_n_device ( HostTensorDescriptor ( c_size, c_stride ) ); + + a_m_k.GenerateTensorValue(GeneratorTensor_2()); + b_k_n.GenerateTensorValue(GeneratorTensor_2()); + {% if has_bias %} + d_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + {% if has_scale %} + s_a_m_n.GenerateTensorValue(GeneratorTensor_2()); + s_b_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + DeviceMem a_m_k_device_buf(sizeof(AElementType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BElementType) * b_k_n.mDesc.GetElementSpaceSize()); + {% if has_bias %} + DeviceMem d_m_n_device_buf(sizeof(BiasElementType) * d_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + {% if has_scale %} + DeviceMem s_a_m_n_device_buf(sizeof(ScaleAElementType) * s_a_m_n.mDesc.GetElementSpaceSize()); + DeviceMem s_b_m_n_device_buf(sizeof(ScaleBElementType) * s_b_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + DeviceMem c_m_n_device_buf(sizeof(CElementType) * c_m_n_device.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + {% if has_bias %} + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + {% endif %} + {% if has_scale %} + s_a_m_n_device_buf.ToDevice(s_a_m_n.mData.data()); + s_b_m_n_device_buf.ToDevice(s_b_m_n.mData.data()); + {% endif %} + + {{kernel_name}}( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {% if has_scale %} + static_cast(s_a_m_n_device_buf.GetDeviceBuffer()), + static_cast(s_b_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + {% if has_bias %} + static_cast(d_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + {% if is_batched %} + B, + {% endif %} + M, + N, + K, + LDA, + LDB, + LDC, + LDD, + nullptr, // workspace_size + nullptr, // workspace + nullptr); // stream + + hip_check_error(hipDeviceSynchronize()); + + return 0; + } // run_main + } // extern C + + int main(int argc, char** argv) { + return run_main(argc, argv); + } + // compile with: {{compile_cmd}} + #endif // GENERATE_CK_STANDALONE_RUNNER + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ) -> None: + is_batched = len(layout.size) == 3 + name = "ck_batched_gemm_template" if is_batched else "ck_gemm_template" + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + input_reorder=input_reorder, + ) + self.alpha = alpha + self.beta = beta + self.is_batched = is_batched + + def header(self) -> IndentedBuffer: + res = super().header() + if self.is_batched: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + else: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + struct MultiplyMultiplyAdd { + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const { + e = ck::type_convert( + ck::type_convert(c) + * ck::type_convert(d0) + * ck::type_convert(d1) + + ck::type_convert(d2) + ); + } + }; + """ + ) + return res + + def inline_utils(self): + res = IndentedBuffer() + res.splice( + """ + #include "host_tensor.cpp" + #include "device_memory.cpp" + """ + ) + return res + + def _has_padding(self, dimension, gemm_specialization): + # Get the relevant padding map for the given dimension + dimension_padding = padding_lookup.get(dimension, {}) + + # Check if the specialization is in the dimension's padding map + return dimension_padding.get(gemm_specialization, False) + + def filter_op(self, op_info: InductorROCmOp): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + op, kBatch = op_info.op, op_info.kBatch + metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_layout(W_meta): + return None + if op.c_layout != torch_layout_to_ck_layout(Y_meta): + return None + # try to avoid launching the instance with invalid problem size + # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity + + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + + if is_static_int(M): + if not self._has_padding("M", op.gemm_specialization): + if M % op.m_per_block != 0: + return None + if is_static_int(N): + if not self._has_padding("N", op.gemm_specialization): + if N % op.n_per_block != 0: + return None + if is_static_int(K): + if not self._has_padding("K", op.gemm_specialization): + if K % op.k_per_block != 0: + return None + K_t = kBatch * op.k_per_block + if K % K_t != 0: + return None + else: + # need another kBatch check here + lcm = abs(op.a_k1 * op.b_k1) // math.gcd(op.a_k1, op.b_k1) + K_t = kBatch * lcm + k_read_pad_splited = math.ceil(K / K_t) * lcm + if (k_read_pad_splited * (kBatch - 1)) >= K: + return None + + a_contig_size = ( + K if op.a_layout == "Row" else M if op.a_layout == "Col" else None + ) + if ( + is_static_int(a_contig_size) + and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0 + ): + return None + b_contig_size = ( + N if op.b_layout == "Row" else K if op.b_layout == "Col" else None + ) + if ( + is_static_int(b_contig_size) + and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0 + ): + return None + c_contig_size = ( + N if op.c_layout == "Row" else M if op.c_layout == "Col" else None + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block[0] + if isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ) + else op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + if ( + is_static_int(c_contig_size) + and c_contig_size % c_shuffle_block_transfer_scalar_per_vector_n_per_block + != 0 + ): + return None + if not self._check_num_k_loops(op, kBatch): + return None + # TBD disable instances with invalid number of pipeline prefetch stages + # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check + + return op + + def _check_num_k_loops(self, op, kBatch): + # Additional splitK scenario check + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + K = X_meta.size[-1] + if kBatch > 1: + if op.block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1": + try: + prefetch_stages = self._prefetch_stages( + op, + torch.empty((), dtype=X_meta.dtype).element_size(), + torch.empty((), dtype=W_meta.dtype).element_size(), + torch.cuda.get_device_properties(X_meta.device).warp_size, + ) + except Exception as e: + log.debug( # noqa: G200 + "Failed to prefetch_stages for %s with exception %s", op.name, e + ) + # be conservative here and disable the op + return False + + K_t = op.k_per_block * kBatch + ak0 = (K + K_t - 1) // K_t * (op.k_per_block // op.a_k1) + num_k_loop = ak0 // (op.k_per_block // op.a_k1) + if num_k_loop <= prefetch_stages: + log.debug( + "Op %s is not compatible due to invalid number of pipeline prefetch stages. " + "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s", + op.name(), + kBatch, + op.block_gemm_pipeline_version, + prefetch_stages, + num_k_loop, + ) + return False + + return True + + # small helper to figure out the prefetch stages on AMD + def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64): + version_str = op.block_gemm_pipeline_version.split("::")[-1] + try: + version = int(version_str[1:]) # Assuming the format is always 'vX' + except ValueError as e: + raise ValueError(f"Invalid version string: {version_str}") from e + if version not in [1, 2, 3, 4, 5]: + raise ValueError( + f"unknown prefetch stages for {op.block_gemm_pipeline_version}" + ) + # Define the mapping of versions to stages + version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3} + # Get the stages for the given version + stages = version_to_stages.get(version) + if stages is None: + # This means we're at stage 2, and this requires computation + # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950 + wgp_per_cu = max(4 * warp_size // op.block_size, 1) + full_mem_band_prefetch_stages = math.ceil( + 32768 + / wgp_per_cu + / ( + (op.m_per_block * a_dtype_size + op.n_per_block * b_dtype_size) + * op.k_per_block + ) + ) + stages = min(max(full_mem_band_prefetch_stages, 2), 8) + + return stages + + def emit_ck_instance(self, op: "CKGemmOperation"): + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + struct_name = ( + "DeviceBatchedGemmMultiD_Xdl_CShuffle_V3" + if self.is_batched + else "DeviceGemmMultiD_Xdl_CShuffle_V3" + ) + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::{{struct_name}}< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + # pyrefly: ignore [bad-argument-type] + template_params.append(arg) + else: + if field_value is not None: + # pyrefly: ignore [bad-argument-type] + template_params.append(f"/* {field_name} */ {field_value}") + operation_name = op.name().replace("(", "").replace(",", "").replace(")", "") + return self._template_from_string(template_definition).render( + operation_name=operation_name, + template_params=(",\n" + 12 * " ").join(template_params), + struct_name=struct_name, + ), self._template_from_string(template_type).render( + operation_name=operation_name + ) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGemmOperation", + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes") + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node") + if template_buffer_node is not None: + self.output_node = template_buffer_node + # input nodes: + # * X, W for matmul + # * X, W, Bias for addmm + # * X, W, inv_scale_x, inv_scale_w for scaled_mm + # * X, W, inv_scale_x, inv_scale_w, Bias for scaled_mm with bias + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = ( + self.input_nodes[2] + if 3 == len(self.input_nodes) + else self.input_nodes[4] + if 5 == len(self.input_nodes) + else None + ) + has_bias = Bias is not None + has_scale = len(self.input_nodes) in (4, 5) + op = copy.deepcopy(op) + + # This parameter is converted into tuple because of change + # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3. + # The first tuple element corresponds to matmul result... + if not isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ): + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, + ) + + if has_scale: + scale_x = self.input_nodes[2] + scale_w = self.input_nodes[3] + if 1 == scale_x.get_numel() and 1 == scale_w.get_numel(): + # tensorwise scale for both X, W + if has_bias: + op.c_elementwise_op = "ScaleAdd" + else: + op.c_elementwise_op = "Scale" + else: + # rowwise scale for both X, W + if has_bias: + op.c_elementwise_op = "MultiplyMultiplyAdd" + else: + op.c_elementwise_op = "MultiplyMultiply" + op.c_shuffle_dtype = "F32" + op.ds_layouts = ( + torch_layout_to_ck_layout(scale_x.get_layout()), + torch_layout_to_ck_layout(scale_w.get_layout()), + ) + op.ds_element_dtypes = ( + self._TORCH_DTYPE_TO_CK[scale_x.get_layout().dtype], + self._TORCH_DTYPE_TO_CK[scale_w.get_layout().dtype], + ) + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (1, 1) + else: + scale_x = None + scale_w = None + + bias_dtype = "" + if Bias is not None: + bias_layout = torch_layout_to_ck_layout(Bias.get_layout()) + bias_dtype = self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype] + op.ds_layouts += (bias_layout,) + op.ds_element_dtypes += (bias_dtype,) + if not has_scale: + op.c_elementwise_op = "Bilinear" + # c_shuffle_dtype is also used for adding bias to matmul result + # before converting down to the result dtype + op.c_shuffle_dtype = op.acc_dtype + # this parameter needs to be set accordingly to bias stride for correct accumulation + if bias_layout == "Row": + # bias has (N, ) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + elif bias_layout == "Col": + # bias has (M, 1) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,) + else: + raise AssertionError( + "Bias layout is neither row-major nor column-major" + ) + # ...and the second tuple element corresponds to the bias + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += ( + bias_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + + instance_definition, instance_type = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + epilogue = None + + if op.c_elementwise_op == "Bilinear" and scale_w is None: + epilogue = f"Bilinear {{ {self.alpha}, {self.beta} }}" + + elif op.c_elementwise_op == "Scale": + epilogue = "Scale { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "ScaleAdd": + epilogue = "ScaleAdd { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "MultiplyMultiply": + epilogue = "MultiplyMultiply {}" + + elif op.c_elementwise_op == "MultiplyMultiplyAdd": + epilogue = "MultiplyMultiplyAdd {}" + + elif op.c_elementwise_op == "PassThrough": + epilogue = "PassThrough {}" + + assert epilogue is not None, "CK GEMM epilogue is not set" + + size_arg_strs = ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"] + if self.is_batched: + size_arg_strs.insert(0, "B") + + res = self._template_from_string(self.gemm_template).render( + inline_utils=self.inline_utils(), + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W, scale_x, scale_w, Bias], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, inv_scale_x, inv_scale_w, Bias, Y", + input_reorder=self.input_reorder, + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + instance_type=instance_type, + a_element_dtype=op.a_element_dtype, + b_element_dtype=op.b_element_dtype, + c_element_dtype=op.c_element_dtype, + bias_element_dtype=bias_dtype, + alpha=self.alpha, + beta=self.beta, + a_elementwise_op="PassThrough {}", + b_elementwise_op="PassThrough {}", + epilogue=epilogue, + has_bias=has_bias, + ds_size=1 + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else 2 + if op.c_elementwise_op == "MultiplyMultiply" + else 3 + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else 0, + ds_names=", ".join( + ["Bias"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["inv_scale_x", "inv_scale_w"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["inv_scale_x", "inv_scale_w", "Bias"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + ds_strides=", ".join( + ["LDD"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["0", "0"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["0", "0", "LDD"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + version_comment=version_comment, + is_batched=self.is_batched, + ds_batch_strides=", ".join([]), # FIXME when supporting baddbmm + ) + + if config.rocm.generate_test_runner: + is_static_problem = all(is_static_int(arg) for arg in self.size_args()) + # NOTE: size_arg_strs is defined above + size_arg_vals = ( + self.size_args() + if is_static_problem + else ( + f"std::stoi(argv[{k}])" for k, _ in enumerate(self.size_args(), 1) + ) + ) + size_args = dict(zip(size_arg_strs, size_arg_vals, strict=True)) + runtime_args = dict( + zip( + [a.name for a in self.get_runtime_arg_info()], + self.get_runtime_arg_values(), + ) + ) + runner_code = self._template_from_string( + self.standalone_runner_template + ).render( + inline_utils=self.inline_utils().getvalue(), + kernel_name=kernel.kernel_name, + has_bias=has_bias, + has_scale=has_scale, + is_batched=self.is_batched, + a_ck_dtype=op.a_element_dtype, + b_ck_dtype=op.b_element_dtype, + c_ck_dtype=op.c_element_dtype, + bias_ck_dtype=op.ds_element_dtypes[0] if has_bias else "", + scale_a_ck_dtype=op.ds_element_dtypes[0] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + scale_b_ck_dtype=op.ds_element_dtypes[1] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + a_torch_dtype=DTYPE_TO_CPP[X.get_layout().dtype], + b_torch_dtype=DTYPE_TO_CPP[W.get_layout().dtype], + c_torch_dtype=DTYPE_TO_CPP[Y.get_layout().dtype], + bias_torch_dtype=DTYPE_TO_CPP[Bias.get_layout().dtype] + if Bias is not None + else "", + scale_a_torch_dtype=DTYPE_TO_CPP[scale_x.get_layout().dtype] + if scale_x is not None + else "", + scale_b_torch_dtype=DTYPE_TO_CPP[scale_w.get_layout().dtype] + if scale_w is not None + else "", + a_layout=torch_layout_to_ck_layout(X.get_layout()), + b_layout=torch_layout_to_ck_layout(W.get_layout()), + c_layout=torch_layout_to_ck_layout(Y.get_layout()), + bias_layout=torch_layout_to_ck_layout(Bias.get_layout()) + if Bias is not None + else "", + compile_cmd=rocm_compile_command( + [""], "", "exe" + ), + **size_args, + **runtime_args, + ) + res += runner_code + + return res + + def _is_rcr_f16(self): + X_meta, W_meta, Y_meta = ( + T.get_layout() for T in [*self.input_nodes, self.output_node] + ) + X_dtype, W_dtype, Y_dtype = ( + self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta) + ) + X_layout, W_layout, Y_layout = ( + torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta) + ) + + return ( + X_dtype == "F16" + and W_dtype == "F16" + and Y_dtype == "F16" + and X_layout == "Row" + and W_layout == "Col" + and Y_layout == "Row" + ) + + # helper to calculate a potentially optimal kBatch(es) for a problem + def _get_kBatch(self, op): + # we only set a higher kBatch if K > 16 * the larger of M and N + # this is a hand-tuned heuristic to start + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + if is_dynamic(*self.input_nodes): + return [1] + if K // max(M, N) < config.rocm.split_k_threshold: + return [1] + # if the user is telling us which kBatches to sweep, just use those + if config.rocm.kBatch_sweep is not None: + return config.rocm.kBatch_sweep + # Calculate the number of blocks needed for each dimension + total_k_blocks = math.ceil(K / op.k_per_block) + # we want to calculate how many blocks we need to fit per CU + cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count + # again, manual heuristics as much larger kBatch are significantly worse in + # initial testing + kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128) + return [kBatch] + + def gen_ops(self) -> list[InductorROCmOp]: + """ + Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + try: + from ck4inductor.batched_universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_batched_gemm_ops_library, + ) + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_gemm_ops_library, + gen_ops_preselected as gen_gemm_ops_preselected, + ) + except ImportError: + return [] + + generator = None + if self.is_batched: + generator = gen_batched_gemm_ops_library + else: + generator = gen_gemm_ops_library + if config.rocm.use_preselected_instances and self._is_rcr_f16(): + generator = gen_gemm_ops_preselected + + assert generator is not None + + rops = generator() + ops = [] + for o in rops: + kBatches = self._get_kBatch(o) + for kBatch in kBatches: + # pyrefly: ignore [bad-argument-type] + ops.append(InductorROCmOp(op=o, kBatch=kBatch)) + + filtered_instances = list(filter(lambda op: self.filter_op(op), ops)) + + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_ck_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op.op, + kBatch=op.kBatch, + ) + + def size_args(self): + X = self.input_nodes[0] + W = self.input_nodes[1] + Bias = ( + self.input_nodes[2] + if len(self.input_nodes) == 3 + else self.input_nodes[4] + if len(self.input_nodes) == 5 + else None + ) + Y = self.output_node + + M = X.get_size()[-2] + K = X.get_size()[-1] + N = W.get_size()[-1] + LDA = X.get_stride()[-2 if X.get_stride()[-1] == 1 else -1] + LDB = W.get_stride()[-2 if W.get_stride()[-1] == 1 else -1] + LDC = Y.get_stride()[-2 if Y.get_stride()[-1] == 1 else -1] + LDD = ( + 0 + if (Bias is None or len(Bias.get_size()) == 1) + else Bias.get_stride()[-2 if Bias.get_stride()[-1] == 1 else -1] + ) + if self.is_batched: + B = X.get_size()[0] + return B, M, N, K, LDA, LDB, LDC, LDD + else: + return M, N, K, LDA, LDB, LDC, LDD diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py new file mode 100644 index 0000000000000000000000000000000000000000..aa935b14af23c2efd667871df5e05798a4434fa8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py @@ -0,0 +1,153 @@ +# mypy: allow-untyped-defs +import logging +import os +from typing import Optional + +from torch._inductor import config +from torch._inductor.utils import is_linux, try_import_ck_lib + + +log = logging.getLogger(__name__) + + +def _rocm_include_paths(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_include = ( + os.path.join(config.rocm.rocm_home, "include") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("include") + ) + + if config.is_fbcode(): + from libfb.py import parutil + + ck_path = parutil.get_dir_path("composable-kernel-headers") + else: + if not config.rocm.ck_dir: + ck_dir, _, _, _ = try_import_ck_lib() + if not ck_dir: + log.warning("Unspecified Composable Kernel directory") + config.rocm.ck_dir = ck_dir + ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home( + "composable_kernel" + ) + + log.debug("Using ck path %s", ck_path) + + ck_include = os.path.join(ck_path, "include") + ck_library_include = os.path.join(ck_path, "library", "include") + + # CK has to take priority over ROCm include paths + # Since CK is potentially more up-to-date + paths = [ + os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include) + ] + if dst_file_ext == "exe": + ck_utility_include = os.path.join(ck_path, "library", "src", "utility") + paths.append(os.path.realpath(ck_utility_include)) + return paths + + +def _rocm_lib_options(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_lib_dir = ( + os.path.join(config.rocm.rocm_home, "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("lib") + ) + hip_lib_dir = ( + os.path.join(config.rocm.rocm_home, "hip", "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("hip", "lib") + ) + + opts = [ + "-include __clang_hip_runtime_wrapper.h", + f"-L{os.path.realpath(rocm_lib_dir)}", + f"-L{os.path.realpath(hip_lib_dir)}", + "-lamdhip64", + ] + if dst_file_ext == "exe": + opts += ["-lpthread", "-lstdc++"] + return opts + + +def _rocm_compiler_options() -> list[str]: + arch_list = config.rocm.arch or ["native"] + gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] + opts = [ + config.rocm.compile_opt_level, + "-x", + "hip", + "-std=c++17", + *gpu_arch_flags, + "-fno-gpu-rdc", + "-fPIC", + "-fvisibility=hidden", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + "-mllvm", + "-enable-post-misched=0", + ] + if config.rocm.is_debug: + opts += ["-DDEBUG_LOG=1", "-g"] + if config.rocm.save_temps: + opts += ["--save-temps=obj"] + if config.rocm.print_kernel_resource_usage: + opts += ["-Rpass-analysis=kernel-resource-usage"] + if config.rocm.flush_denormals: + opts += ["-fgpu-flush-denormals-to-zero"] + if config.rocm.use_fast_math: + opts += ["-ffast-math"] + return opts + + +def rocm_compiler() -> Optional[str]: + if is_linux(): + if config.rocm.rocm_home: + return os.path.realpath( + os.path.join(config.rocm.rocm_home, "llvm", "bin", "clang") + ) + try: + from torch.utils import cpp_extension + + return os.path.realpath( + cpp_extension._join_rocm_home("llvm", "bin", "clang") + ) + except OSError: + # neither config.rocm.rocm_home nor env variable ROCM_HOME are set + return "clang" + return None + + +def rocm_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[list[str]] = None, +) -> str: + include_paths = _rocm_include_paths(dst_file_ext) + lib_options = _rocm_lib_options(dst_file_ext) + compiler_options = _rocm_compiler_options() + compiler = rocm_compiler() + options = ( + compiler_options + + (extra_args or []) + + [f"-I{path}" for path in include_paths] + + lib_options + ) + src_file = " ".join(src_files) + # supported extensions: .o, .so, .exe + if dst_file_ext == "o": + options.append("-c") + elif dst_file_ext == "so": + options.append("-shared") + elif dst_file_ext == "exe": + options.append("-DGENERATE_CK_STANDALONE_RUNNER") + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + return f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a87bef820dfc76037b5294b00a5f25f26be223 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from ctypes import byref, c_int, c_size_t, c_void_p +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.autotune_process import ( + BenchmarkRequest, + GPUDeviceBenchmarkMixin, + TensorMeta, +) +from torch._inductor.codecache import DLLWrapper, ROCmCodeCache + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + +log = logging.getLogger(__name__) + + +class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate code cache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + ROCmCodeCache.compile(self.source_code, "so") + if config.rocm.generate_test_runner: + ROCmCodeCache.compile(self.source_code, "exe") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + size_args = [c_int(arg) for arg in self.extra_args] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *size_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + dict.fromkeys(meta.name for meta in self.input_tensor_meta) + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + size_args = [c_int(arg) for arg in self.extra_args] + run_method( + *args, # input ptrs and output ptrs + *size_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..ec58e458df6b110fab0c452ec261861d0c2d7cef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -0,0 +1,100 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import cast + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for ROCm C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and ROCm C++ specific template code generation. + """ + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ROCmTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["rocm", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.rocm(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a ROCm template, possibly with fused epilogues + """ + assert self.is_rocm_cpp_template(template_node), ( + "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + self.codegen_comment(node_schedule, kernel_name) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..c5981e129cb192d1f6e0ce1f445401e8c7e51b5e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Callable, Sequence +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch._inductor.config as config +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.utils import do_bench_using_profiling + +from ...ir import ( + Buffer, + ChoiceCaller, + IRNode, + Layout, + PrimitiveInfoType, + ShapeAsConstantBuffer, + TensorBox, +) +from ...virtualized import V +from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode +from ..cpp_utils import CppPrinter +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +if TYPE_CHECKING: + from torch._inductor.codegen.rocm.rocm_template import ArgInfo, ROCmTemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class ROCmKernel(Kernel): + """ + Baseclass for ROCm based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class ROCmTemplateKernel(ROCmKernel): + """ + Template kernels defined by ROCm in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the ROCmTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: dict[str, IRNode] = {} + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def get_signature(self): + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + size_args: list[str], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder == [2, 0, 1]: + input_reorder = [4, 0, 1, 2, 3] + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + + runtime_arg_defs = [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args + runtime_arg_defs)},{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "ROCmTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The ROCmTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # Kinda hacky because we always originally initialize name with "KERNEL_NAME" + # So, we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace("KERNEL_NAME", name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + kernel_args = [] + for arg in call_args: + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + if V.graph.is_unspec_arg(arg): + arg = arg + ".item()" + else: + if not V.graph.cpp_wrapper: + arg = f"c_void_p({arg}.data_ptr())" + kernel_args.append(arg) + + # add size args + size_args = [ + f"{V.graph.sizevars.simplify(sarg)}" for sarg in node.template.size_args() + ] + + if V.graph.cpp_wrapper: + kernel_args.extend(size_args) + else: + kernel_args.extend(f"c_int({sarg})" for sarg in size_args) + + if V.graph.cpp_wrapper: + arg_types.extend(["int"] * len(node.template.size_args())) + + # the runtime args come right after the size args + kernel_args.extend(self.runtime_arg_values) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" + kernel_args.append( + data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" + ) + else: + ws = None + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + wrapper.generate_kernel_call( + name, + kernel_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + +class ROCmTemplateCaller(ChoiceCaller): + """ + ROCmTemplateCaller + + This class represents a caller for ROCm template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (ROCmBenchmarkRequest): The benchmark request for the caller. + template_buffer (ROCmTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [ROCmTemplateBuffer, Optional[Sequence[IRNode]]], str + ], + bmreq: ROCmBenchmarkRequest, + template: "ROCmTemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})" + + def call_name(self) -> str: + return f"rocm_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "ROCm", + "name": self.name, + **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] + } + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + self.bmreq.update_workspace_size() + return TensorBox.create( + ROCmTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bfeb03eabc72d7cf9bce701f535e612644a806c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional +from unittest.mock import patch + +from ...autotune_process import TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +log = logging.getLogger(__name__) + + +# FIXME: unify with the CUDA version +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class ROCmTemplate(KernelTemplate): + index_counter = itertools.count() + gfx9_threads_per_warp = 64 + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for ROCm C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the ROCmTemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> ROCmTemplateCaller: + """ + Generates the ROCm template caller object for the given GEMM template and operation. This ROCmTemplateCaller + may be used to call and benchmark the generated ROCm kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A ROCmTemplateCaller object representing the generated ROCm template caller. + """ + kernel_name = f"rocm_{self.name}" + kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + ROCmTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + + size_args = ( + self.size_args() if hasattr(self, "size_args") else () + ) # subclass should define def size_args() + size_args_ints = [ + V.graph.sizevars.size_hint(arg) for arg in size_args + ] # resolve to ints for benchmarking + # The runtime args come right after the size args + runtime_args = self.get_runtime_arg_values(**kwargs) + extra_args = size_args_ints + runtime_args + bmreq = ROCmBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ROCmTemplateBuffer, + epilogue_nodes: Optional[Sequence[IRNode]] = None, + ): + kernel = ROCmTemplateKernel( + kernel_name="KERNEL_NAME", + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return ROCmTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c90d71c19980279f13cad86d7385825ba7212c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py @@ -0,0 +1,27 @@ +from collections.abc import Callable, Sequence +from typing import TypeVar +from typing_extensions import ParamSpec + +from ...ir import Buffer, Layout, TemplateBuffer + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class ROCmTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[Buffer], + make_kernel_render: Callable[_P, _T], + workspace_size: int, + template: "ROCmTemplate", # type: ignore[name-defined] # noqa: F821 + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self) -> int: + return self.workspace_size if self.workspace_size is not None else 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36871ac5c7f8fcf0a8b91a143168ab1b90530b0b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs + + +import torch + +from ..cpp_utils import DTYPE_TO_CPP + + +DTYPE_TO_ROCM_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "uint16_t", + torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e5m2fnuz: "uint8_t", + torch.float8_e4m3fn: "uint8_t", + torch.float8_e5m2: "uint8_t", + torch.bfloat16: "uint16_t", +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..023f568b4dfc96da5757cb2adb14be2d73ba8d80 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38fafb7018f9bbb4112cfdbee37d07a9426c2705 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..876d5053a94f667e550c6c7444e4e92f0907d5c7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e548efc086a45815d3ca892badbad3ed73e6628 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/timer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..221745d9bcf10cebbda7e374d75b7c80f04d393f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f875f52269a6a00d11378a3df9c4892a0a157717 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c29b5a5cf54660b47c49a696ec04462bda01ac9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43186e9f55b08113bb395bee80628ff123c55ef4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a6b6027dd5bde8f83add37c23af70605b1a8ef Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdb886f5be65ead84874aba051bbda9c75a7e279 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_cpu.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4fbc16a4a912c95996f01eae36ff8e4ad96b39c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_decoding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..535e8bc045f89d348d6f530448dba6ccf7a0d6f6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/__pycache__/flex_flash_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..0e83853fa5de8e2ae1a66726bf6e67b1a45fd212 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja @@ -0,0 +1,76 @@ +{% if NEEDS_BLOCK_MASK %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} +{% else %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP")}} +{% endif %} + from flash_attn.cute.interface import _flash_attn_fwd + from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch + + # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D) + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + + @cute.jit + def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=0, + output_name="tSrS_ssa", + score="tSrS_ssa", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + out="tSrS_ssa" + ) | indent_except_first(2) }} + return tSrS_ssa + {{ set_cute_hash("score_mod", "score") }} + + # (B,M,H,D) -> (B,H,M,D) + output = {{get_output()}} + output_transposed = output.transpose(1, 2) + + {% if NEEDS_BLOCK_MASK %} + @cute.jit + def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + ) | indent_except_first(2) }} + return mask_mod_output + {{ set_cute_hash("mask_mod", "mask") }} + block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX) + {% else %} + block_sparse_tensors = None + mask_mod = None + {% endif %} + + # Collect any additional tensor buffers that were added during modifications + {% set tensor_buffers = get_tensor_buffers() -%} + {% if tensor_buffers -%} + buffers = [{% for buffer in tensor_buffers %}{{buffer}}{% if not loop.last %}, {% endif %}{% endfor %}] + buffers = list(buffers) + {% else -%} + buffers = None + {% endif -%} + + # Out and LSE filled inplace + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + softmax_scale={{SM_SCALE}}, + return_lse=True, + score_mod=score_mod, + mask_mod=mask_mod, + out=output_transposed, + lse=LOGSUMEXP, + block_sparse_tensors=block_sparse_tensors, + aux_tensors=buffers + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..3467d84475d0ce70fba937df75f7d93caabfb5dd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja @@ -0,0 +1,620 @@ +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8, val_shape=("BLOCK_N1", "QK_HEAD_DIM_ROUNDED"))}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + + {# Note: Selective masking DQ + We load elements beyond KV_LEN w/ zero, some score mods may convert this elements to NaN + Example: lambda x, *_: 1 / score, this NaN would propagate regardless of other masking + We only need to do this on the m1 dim since these elements take part in the final reduction + for DQ #} + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + {# See Note Selective masking DQ #} + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + {# Note: Selective masking DK/DV + We load elements beyond Q_LEN w/ zero, some score mods may convert this elements to NaN + Example: lambda x, *_: 1 / score, this NaN would propagate regardless of other masking + We only need to do this on the m1 dim since these elements take part in the final reduction + for DK/DV #} + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + {# See Note: Selective masking DK/DV#} + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja new file mode 100644 index 0000000000000000000000000000000000000000..e5f0e118c5631404b0f1fda5086e2447f64e4fbe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja @@ -0,0 +1,242 @@ + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = {{size("KV_IDX", -1)}} + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask", val_shape=("GQA_SHARED_HEADS", "BLOCK_M_PER_HQ", "V_HEAD_DIM"))}} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29870ff846de034bfdef4edc3c0dd20b95f69799 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93b903da8b66b4c919ca8cb07a2dc34d2a4975c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/vendored_templates/__pycache__/cutedsl_grouped_gemm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32eb29bfd10933cbab5df1107ca4f35990f29b36 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b8aac9c06fbe3033dd721b9b912f00d54854753 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/aten.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368efa23671914258119c9bf9497fcc81e78b1a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde546477c4b8397fc147cf636ff2840c15ebf08 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/contiguous_mm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685dc9c3e1c721d9f05607a1f811c409516aec36 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/cutedsl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b07d065c6c1d9eb83b9b0441721086ce1c997143 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/decompose_k.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea7b4832c9a6066312ec98bcf9d51f18e7ad3864 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/gemm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10f7300c9b57914978a2331864a69cbfcbc1dcfd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/params.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07711b2f9b95274cee3efd5f5780397bbe346a6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d015c3ae0e8f5b16d39bfbbde6d7e86dbcaa388 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/template_heuristics/__pycache__/triton_addmm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0126f19be8f33003dc8916d2db7bd7cf7f74dc23 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6520ed8e9210f4269a00ab02f5e0a3345ec594c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/autograd.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dc891251976d7d8d95c5214eedf09d289c2ef6e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/custom_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4454b371f9f03586e62865d6cbfbf9e34494d070 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/effects.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85f9813bf9c6193df7fa73acaad04cc655f64ad Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6e0d97d962184180bbe12e5ccb7a6514548c57 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da584d2f67bc36471d86a4ef6caa22e70c6c2a2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/fake_profile.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9913ee9b3dfe250d51c79ffe8ce465a18a4cc8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/infer_schema.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4cc69b9838cc5f5da8f192309af54a00082925a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/opaque_object.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2291cd88f3e352d2864344c0d073612af082a6e2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/simple_registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20cb0f11e369957dcf30290ae289107bd9c5d1b8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/triton.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3541869decf343374249ad69644a24d5751e110b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c882e6e0eb2df358a8eb4f41f182dda3c2b77e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5359c357fe104d1905f6ccdab5be14ecec72b58c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..393e42b06d15cf4736c23a03e87d05468ee0ab35 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,435 @@ +# mypy: allow-untyped-defs +import math +from functools import partial +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + f"without narrowing to the specified dtype ({dtype})", + ) + + +import operator + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +def _check_vector_norm_args( + x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None +): + from torch.fx.experimental.symbolic_shapes import sym_or + + if not (ord < 0.0 or ord == float("inf")): + return + + torch._check( + sym_or( + x.numel() != 0, + not isinstance(dim, IntLike) and dim is not None and len(dim) != 0, + ), + lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + + shape = x.shape + if dim is not None and not isinstance(dim, IntLike): + for d in dim: + torch._check( + sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0), + lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + _check_vector_norm_args(x, ord, dim) + + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if dim == []: + dim = None + + if (dim is None and x.numel() == 1) or ( + dim is not None + and (x.ndim > 0 and all(guard_or_false(x.shape[d] == 1) for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return to_result_dtype(x).flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + # pyrefly: ignore [index-error] + dim[0] != dim[1], + # pyrefly: ignore [index-error] + lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: f"linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + # pyrefly: ignore [index-error] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: f"linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + def _max_min_wrapper(A, dim): + # pyrefly: ignore [unsupported-operation] + if A.size(dim) == 0 and ord > 0.0: + new_size = list(A.size()) + if keepdim: + new_size[dim] = 1 + else: + del new_size[dim] + return torch.zeros(new_size, dtype=A.dtype, device=A.device) + else: + return max_min(A, dim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + # pyrefly: ignore [index-error] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = _max_min_wrapper(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + # pyrefly: ignore [bad-unpacking] + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return _max_min_wrapper( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) # type: ignore[arg-type] + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) + + +def _pivots_to_permutation(pivots, shape, *, inverse=False): + perm = torch.empty(shape, dtype=torch.int32, device=pivots.device) + perm[..., :] = torch.arange(shape[-1], dtype=torch.int32, device=pivots.device) + indices = range(shape[-1]) + if inverse: + indices = reversed(indices) + + if len(shape) > 1: + for i in indices: + j_s = pivots[..., i] + perm_i = perm[..., i].clone() + j_idx = torch.meshgrid( + *[torch.arange(s, device=perm.device) for s in j_s.shape], indexing="ij" + ) + (j_s,) + perm_j = perm[j_idx] + perm.index_put_(j_idx, perm_i) + perm[..., i].copy_(perm_j) + + else: + for i in indices: + j = pivots[i] + perm_i = perm[i].clone() + perm_j = perm[j].clone() + perm[i].copy_(perm_j) + perm[j].copy_(perm_i) + + return perm + + +def _apply_pivots(a, pivots, shape, *, inverse=False): + perm = _pivots_to_permutation(pivots - 1, shape, inverse=inverse) + + if len(shape) == 1: + return a[perm, :] + else: + idx = torch.meshgrid( + *[torch.arange(s, device=a.device) for s in perm.shape], indexing="ij" + )[:-1] + (perm, slice(None)) + return a[idx] + + +def linalg_lu_solve_out_mps(LU, pivots, B, *, left=True, adjoint=False, out): + if out.numel() == 0: + return + + if not left: + adjoint = not adjoint + B = B.mH + + if adjoint: + lu_ = LU.mH + x = torch.linalg.solve_triangular(lu_, B, left=True, upper=False) + x = torch.linalg.solve_triangular( + lu_, x, left=True, upper=True, unitriangular=True + ) + x = _apply_pivots(x, pivots, LU.shape[:-1], inverse=True) + else: + x = _apply_pivots(B, pivots, LU.shape[:-1]) + x = torch.linalg.solve_triangular( + LU, x, left=True, upper=False, unitriangular=True + ) + x = torch.linalg.solve_triangular(LU, x, left=True, upper=True) + + if not left: + x = x.mH + + out.copy_(x) + + +mps_lib = torch.library.Library("aten", "IMPL", "MPS") # noqa: TOR901 +mps_lib.impl("aten::linalg_lu_solve.out", linalg_lu_solve_out_mps) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e187a76d82cc675f587d959007378cfd75425ebf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c2ef67bd9d44a21f9d3673ba631c0840740ced --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..853509547f5ce3fc893efc4d8c19ea5165511010 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..135788a439de5cf7882f659133ce63649d8308e2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1293 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +from collections.abc import Callable +from functools import wraps +from typing import Concatenate, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "channel_shuffle", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", + "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + "nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + "softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + # pyrefly: ignore [unsupported-operation] + a = args[0] + if "inplace" not in kwargs: + kwargs["inplace"] = False + # pyrefly: ignore [unsupported-operation] + if kwargs["inplace"]: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + kwargs["inplace"] = False + kwargs["out"] = a + return fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +@register_decomposition(aten.channel_shuffle) +@out_wrapper() +def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType: + """ + Reference implementation of :func:`torch.nn.functional.channel_shuffle`. + """ + from torch._meta_registrations import device_hint + + torch._check( + input.dim() > 2, + lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}", + ) + c = input.shape[1] + torch._check( + groups > 0, + lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}", + ) + torch._check( + (c % groups) == 0, + lambda: f"Number of channels must be divisible by groups. Got {c} channels and {groups} groups.", + ) + n = input.shape[0] + cg = c // groups + dhw = input.shape[2:] + + if input.numel() == 0 or ( + device_hint(input) == "cuda" and (groups == 1 or groups == c) + ): + return input.view(input.shape) + + return ( + input.reshape(n, groups, cg, *dhw) + .transpose(1, 2) + .reshape(input.shape) + .contiguous() + ) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + 0 <= lambd <= torch.finfo(a.dtype).max, + lambda: f"lambda must be in range [0, {torch.finfo(a.dtype).max}] for input dtype {a.dtype}, but found {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + # We multiply by 0 for dealing with nans + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + # pyrefly: ignore [unsupported-operation] + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, -target * (input1 - input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classes, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. +def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + + if min_val > max_val: # type: ignore[operator] + raise ValueError("min_val cannot be greater than max_val") + + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f025305419e723d67ca540baeeae2b9c5d328ea9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7351fb8f10cad27819c4ee7cf17805a3f3d37bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.where(self < lo, lo, torch.where(self > hi, hi, self)) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(a, TensorLike) and isinstance(b, Number): + # pyrefly: ignore [bad-argument-type] + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + # pyrefly: ignore [bad-argument-type] + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e4fc5357bfbb08c33bf7fd3cf348e3c6f87f6fc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b513700e1c430e45cd16c61043cb244d2fa5947 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce08bd053cf58722d330345275bcc9bab7246314 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6aae733ee41e687661d5f46ecbf9391da7555d7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d364e2fe0072a08579e0ec3a3c7e6be8ac778b2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8c3e9cc0fc7d86644846babe2b51e63bf1a97a9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88a59389ac2836223c637a111b0ac7e011d5d01 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e3cb102df6dbdf42f0ad57632233db96bcfb13 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33c5cd1b9addef2697057e7bd51113c76c64efed Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c4738af5c21140384bf8acd5a39e717b6ee1ceb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2a80b1c1e550d6ff76872d64990298161a9ced8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab4a816261dc088cb740c6dd85575de8a36a0f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__init__.py @@ -0,0 +1,9 @@ +from ._core import ComplexTensor +from ._ops import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"] + +ComplexTensor.__module__ = __name__ +ComplexTensorMode.__module__ = __name__ +is_complex_tensor.__module__ = __name__ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29250bba9882f73768722ee4fb72f372a2d8b002 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87dfc4113f35589af17c6bede6f6660574353d2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/__pycache__/_core.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py new file mode 100644 index 0000000000000000000000000000000000000000..edd7568b2ef06dc3f3e6e9e2f67a586aa15c984f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_core.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +from torch import Tensor +from torch.autograd import Function + + +if TYPE_CHECKING: + from torch._ops import OpOverload + from torch._prims_common import DeviceLikeType + from torch.autograd.function import FunctionCtx + + +class ComplexTensor(Tensor): + """A class that decomposes all ops on complex Tensors into their real and imaginary parts.""" + + _re: Tensor + _im: Tensor + + def __new__(cls, real: Tensor, imag: Tensor) -> Self: + """Initialize a ComplexTensor from its real and imaginary parts.""" + from ._ops.common import REAL_TO_COMPLEX + + shape = real.shape + device = real.device + + # TODO (hameerabbasi): `torch.compile` sometimes fails here without making these + # contiguous. Why? + real = real.contiguous() + imag = imag.contiguous() + + # TODO (hameerabbasi): + # What should we do with dtype? + # We could convert to the complex type (float32 -> complex64), but we + # can't use that model for say `bfloat16` which does not have a + # corresponding complex dtype. + # If we want to support this complex rep using any float type (see + # https://github.com/pytorch/pytorch/issues/95100) + # We either need to: + # 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere + # else. + # 2) We use the real float dtype here, and it is up to the user to know + # that dtype=float here really means complex<2xSize> with dtype + # matching that of re/im parts alone + # I'm going with 1 for now, so that I can make gradcheck and some complex + # ops work properly, but might want to discuss this in the RFP. + dtype = REAL_TO_COMPLEX.get(real.dtype) + if dtype is None: + raise TypeError( + "Unsupported dtype for constituent tensors. Supported dtypes are: " + f"{set(REAL_TO_COMPLEX.keys())!r}." + ) + storage_offset = real.storage_offset() + strides = real.stride() + layout = real.layout + pin_memory = real.is_pinned() + + assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}" + assert device == imag.device, ( + f"Expected imag device {device}, got {imag.device}" + ) + assert real.dtype == imag.dtype, ( + f"Expected imag dtype {real.dtype}, got {imag.dtype}" + ) + assert pin_memory == imag.is_pinned(), ( + f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}" + ) + + res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=device, + dtype=dtype, + storage_offset=storage_offset, + strides=strides, + pin_memory=pin_memory, + layout=layout, + requires_grad=False, + ) + res._re = real.clone().detach() + res._im = imag.clone().detach() + + return res + + @property + def re(self) -> Tensor: + return self._re + + @property + def im(self) -> Tensor: + return self._im + + @classmethod + def __torch_dispatch__( + cls, + func: OpOverload, + types: tuple[type, ...], + args: tuple = (), + kwargs: dict | None = None, + ): + from ._ops.common import lookup_complex + + kwargs = {} if kwargs is None else kwargs + + impl = lookup_complex(func, *args, **kwargs) + if impl is None: + return NotImplemented + + return impl(*args, **kwargs) + + @staticmethod + def from_interleaved(t: Tensor) -> ComplexTensor: + t_real = torch.real(t) + t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real) + return Complex.apply(t_real, t_imag) + + def as_interleaved(self) -> Tensor: + return torch.complex(self.real, self.imag) + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Tensor], + meta: Any, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + ) -> ComplexTensor: + assert meta is None + re, im = inner_tensors["re"], inner_tensors["im"] + return ComplexTensor(re, im) + + def __tensor_flatten__(self) -> tuple[list[str], Any]: + return ["re", "im"], None + + def __repr__(self, *, tensor_contents=None) -> str: + return f"ComplexTensor(real={self.re!r}, imag={self.im!r})" + + def is_pinned(self, device: DeviceLikeType | None = None) -> bool: + return self.re.is_pinned(device) + + +class Complex(Function): + @staticmethod + def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override] + return ComplexTensor(real, imag) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override] + return grad_output.real, grad_output.imag diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c07bdf6099b65d477e45bc7e18078eb53201dc4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__init__.py @@ -0,0 +1,5 @@ +from . import aten, prims +from .common import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56bc901fe0f84643b5f51f415852f5446d80e06 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87765582dacc34b0439a7916dd8377d116467a43 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/aten.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c97c78b218673333344ad31073f0bb82fa3540a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd03db9cc4c9e96703b07bf538e410a7597e8b6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/__pycache__/prims.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py new file mode 100644 index 0000000000000000000000000000000000000000..e638e5413c2cdc4878756c7878fb700d4901c551 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/aten.py @@ -0,0 +1,934 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import torch + +from .._core import ComplexTensor +from .common import ( + _get_func_name, + COMPLEX_TO_REAL, + complex_to_real_dtype, + is_complex, + OpType, + promote_tensors, + register_binary_nonlinear, + register_complex, + register_error, + register_force_test, + register_simple, + split_complex_arg, + split_complex_tensor, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +aten = torch.ops.aten + + +def register_binary_linear(op: OpType): + def impl_with_alpha( + lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs + ) -> ComplexTensor: + return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs) + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + u = op(a_r, b_r, *args, **kwargs) + v = op(a_i, b_i, *args, **kwargs) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + return register_complex(op, impl) + + +@register_complex(aten.real) +def real_impl(self: ComplexTensor) -> torch.Tensor: + re, _ = split_complex_tensor(self) + return re + + +@register_complex(aten.imag) +def imag_impl(self: ComplexTensor) -> torch.Tensor: + _, im = split_complex_tensor(self) + return im + + +@register_complex(aten.is_pinned) +def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool: + return self.is_pinned(device) + + +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.unsqueeze_, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + aten.zero_, + aten.transpose, + aten.t, + aten.gather, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[_get_func_name(simple_op)] = register_simple(simple_op) + +# TODO (hameerabbasi): Not being tested +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, + aten.embedding, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[_get_func_name(simple_op)] = register_force_test( + simple_op, register_simple(simple_op) + ) + +del simple_op + +# some binary ops which we can stamp out +mul_impl = register_binary_nonlinear(aten.mul) +mul__impl = register_binary_nonlinear(aten.mul_) +mm_impl = register_binary_nonlinear(aten.mm) +dot_impl = register_binary_nonlinear(aten.dot) +bmm_impl = register_binary_nonlinear(aten.bmm) + +# TODO (hameerabbasi): Not being tested +convolution_impl = register_force_test( + aten.convolution, register_binary_nonlinear(aten.convolution) +) + +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + +add_impl = register_binary_linear(aten.add) +add__impl = register_binary_linear(aten.add_) +sub_impl = register_binary_linear(aten.sub) +sub__impl = register_binary_linear(aten.sub_) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) +fill__impl = register_binary_linear(aten.fill_) + + +@register_complex(aten.rsub) +def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor: + if alpha is None: + return torch.sub(rhs, lhs) # type: ignore[bad-return] + return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return] + + +@register_complex(aten.div) +@register_complex(aten.true_divide) +def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): + if rounding_mode is not None: + raise NotImplementedError( + "`rounding_mode` other than `None` not implemented for`ComplexTensor`." + ) + a_r, a_i = split_complex_arg(lhs) + if not is_complex(rhs): + return ComplexTensor(a_r / rhs, a_i / rhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + num_r = a_r * b_r + a_i * b_i + num_i = a_i * b_r - a_r * b_i + den = b_r * b_r + b_i * b_i + return ComplexTensor( + (num_r / den).to(out_dt), + (num_i / den).to(out_dt), + ) + + +@register_complex(aten.reciprocal) +def reciprocal_impl(self: ComplexTensor): + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + den = self_r * self_r + self_i * self_i + return ComplexTensor( + aten.div(self_r, den).to(out_dt), + aten.div(-self_i, den).to(out_dt), + ) + + +# reductions +@register_complex(aten.prod) +def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + dtype = kwargs.pop("dtype", out_dt) + kwargs["dtype"] = complex_to_real_dtype(self.dtype) + + prod_r = torch.prod(torch.abs(self), *args, **kwargs) + sum_phi = torch.sum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return] + + +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + out_dt, (self, exponent) = promote_tensors(self, exponent) + return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.cumprod) +def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + +# unary funcs, +# most of these are simple or require some kind of identity +@register_complex(aten.abs) +def abs_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + result = torch.hypot(x, y) + return result.to(out_dt) + + +@register_complex(aten.angle) +def angle_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.atan2(y, x) + + +@register_complex(aten.acos) +def acos_impl(self: ComplexTensor) -> ComplexTensor: + _, y = split_complex_tensor(self) + acosh_z = torch.acosh(self) + assert isinstance(acosh_z, ComplexTensor) + acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z) + sign_im = 2 * torch.signbit(y) - 1 + return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re)) + + +@register_complex(aten.asin) +def asin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + asinh_iz = torch.asinh(ComplexTensor(-y, x)) + assert isinstance(asinh_iz, ComplexTensor) + asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz) + return ComplexTensor(asinh_iz_im, -asinh_iz_re) + + +@register_complex(aten.atan) +def atan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.atanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.asinh) +def asinh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.acosh) +def acosh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.atanh) +def atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + ret = 0.5 * ( + torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y)) + ) + assert isinstance(ret, ComplexTensor) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.cos) +def cos_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return] + + +@register_complex(aten.cosh) +def cosh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.cosh(x) * torch.cos(y) + v = torch.sinh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + sinh_iz = torch.sinh(ComplexTensor(-y, x)) + assert isinstance(sinh_iz, ComplexTensor) + sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz) + return ComplexTensor(sinh_iz_im, -sinh_iz_re) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.sinh(x) * torch.cos(y) + v = torch.cosh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.tanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + _2x = 2 * x + _2y = 2 * y + _d = torch.cosh(_2x) + torch.cos(_2y) + _2xsh = torch.sinh(_2x) + + out_re = _2xsh / _d + out_im = torch.sin(_2y) / _d + + return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt)) + + +@register_complex(aten.exp) +def exp_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + ex = torch.exp(x) + u = ex * torch.cos(y) + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.expm1) +def expm1_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + # TODO (hameerabbasi): The two lines below may have numerical issues + ex = torch.exp(x) + u = ex * torch.cos(y) - 1 + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return] + + +@register_complex(aten.any) +def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs) + + +@register_complex(aten.all) +def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs) + + +@register_complex(aten.eq) +def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_arg(self) + b_r, b_i = split_complex_arg(rhs) + return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.ne) +def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.isnan) +def isnan_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isnan(re) | torch.isnan(im) + + +@register_complex(aten.isinf) +def isinf_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isinf(re) | torch.isinf(im) + + +@register_complex(aten.isfinite) +def isfinite_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isfinite(re) & torch.isfinite(im) + + +@register_complex(aten.isclose) +def isclose_impl( + self: ComplexTensor, + rhs: ComplexTensor, + rtol=1e-5, + atol=1e-8, + equal_nan: bool = False, +) -> torch.Tensor: + abs_diff = torch.abs(self - rhs) + abs_other = torch.abs(rhs) + basic_condition = abs_diff <= (rtol * abs_other + atol) + + # This is the nontrivial part + if equal_nan: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + + a_r_nan = torch.isnan(a_r) + b_r_nan = torch.isnan(b_r) + a_i_nan = torch.isnan(a_i) + b_i_nan = torch.isnan(b_i) + a_nan = a_r_nan | a_i_nan + + # This logical expression makes sure that the isnan of both the real and imaginary parts + # matches (so 1 + nan*i doesn't equal nan + 1*i) + equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan + return basic_condition | equal_nan_condition + + return basic_condition + + +ERROR_OPS_LIST = [ + aten.lt, + aten.le, + aten.gt, + aten.ge, + aten.amin, + aten.amax, + aten.clamp, + aten.ceil, + aten.floor, + aten.minimum, + aten.maximum, + aten.trunc, + aten.sign, + aten.argmax, + aten.argmin, + aten.sort, + aten.topk, + aten.round, + aten.fmod, +] + + +ERROR_TYPES = { + aten.minimum: RuntimeError, + aten.maximum: RuntimeError, + aten.argmax: RuntimeError, + aten.argmin: RuntimeError, + aten.sort: RuntimeError, + aten.topk: RuntimeError, +} + + +for err_op in ERROR_OPS_LIST: + globals()[_get_func_name(err_op)] = register_error( + err_op, ERROR_TYPES.get(err_op, NotImplementedError) + ) + +del err_op + + +@register_complex(aten.masked_scatter) +def masked_scatter_impl( + self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor +) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + source_r, source_i = split_complex_arg(source) + ret_r = torch.masked_scatter(self_r, mask, source_r) + ret_i = torch.masked_scatter(self_i, mask, source_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.where) +def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor: + x_r, x_i = split_complex_arg(x) + y_r, y_i = split_complex_arg(y) + + ret_r = torch.where(mask, x_r, y_r) + ret_i = torch.where(mask, x_i, y_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.full_like) +def full_like_impl( + input: ComplexTensor, + fill_value: complex, + *args, + dtype: torch.dtype | None = None, + **kwargs, +) -> torch.Tensor | ComplexTensor: + # Note: Cannot be merged with the cases below due to the `fill_value` argument + input_r, input_i = split_complex_tensor(input) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + fv_r, fv_i = split_complex_arg(fill_value) + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) + + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] + +for like_op in LIKE_OPS_LIST: + globals()[_get_func_name(like_op)] = register_like(like_op) + +del like_op + + +@register_complex(aten.cat) +def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: + tensors_r = [] + tensors_i = [] + + for t in tensors: + t_r, t_i = split_complex_arg(t) + tensors_r.append(t_r) + tensors_i.append(t_i) + + ret_r = torch.cat(tensors_r, dim=dim) + ret_i = torch.cat(tensors_i, dim=dim) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.sgn) +def sgn_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = (self_r != 0) | (self_i != 0) + masked_sgn = ComplexTensor( + (self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt) + ) + return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return] + + +@register_complex(aten.sqrt) +def sqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_sqrt = torch.sqrt(torch.abs(self)) + self_half_angle = 0.5 * torch.angle(self) + + ret_r = self_abs_sqrt * torch.cos(self_half_angle) + ret_i = self_abs_sqrt * torch.sin(self_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.rsqrt) +def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_rsqrt = torch.rsqrt(torch.abs(self)) + self_neg_half_angle = -0.5 * torch.angle(self) + + ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) + ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.addmm) +def addmm_impl( + input: ComplexTensor, + mat1: ComplexTensor, + mat2: ComplexTensor, + out_dtype: torch.dtype | None = None, + beta: complex = 1, + alpha: complex = 1, +) -> ComplexTensor: + ret = beta * input + alpha * torch.mm(mat1, mat2) + assert isinstance(ret, ComplexTensor) + ret_r, ret_i = split_complex_tensor(ret) + if out_dtype is not None: + out_dtype = COMPLEX_TO_REAL[out_dtype] + ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype) + return ComplexTensor(ret_r, ret_i) + + +def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return (re != 0) | (im != 0) + + +def register_nonzero_impl(op: OpType): + def nonzero_impl( + self: ComplexTensor, other: ComplexTensor, *args, **kwargs + ) -> torch.Tensor: + return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs) + + func_name = _get_func_name(op) + nonzero_impl.__name__ = func_name + nonzero_impl.__qualname__ = func_name + + return register_complex(op, nonzero_impl) + + +logical_and_impl = register_nonzero_impl(aten.logical_and) +logical_or_impl = register_nonzero_impl(aten.logical_or) +logical_xor_impl = register_nonzero_impl(aten.logical_xor) + + +@register_complex(aten.logical_not) +def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + +@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl( + self: ComplexTensor | torch.Tensor, + src: ComplexTensor | torch.Tensor, + *args, + **kwargs, +) -> ComplexTensor | torch.Tensor: + if not self.dtype.is_complex: + warnings.warn( + "Casting complex values to real discards the imaginary part", UserWarning + ) + src_re, src_im = split_complex_arg(src) + return self.copy_(src_re) + + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + return torch.all( + torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() # type: ignore[bad-return] + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, torch._neg_view(im)) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl( + self: ComplexTensor, pad, value: complex | None = None +) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) + + +@register_complex(aten.scatter_add) +def scatter_add_impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = torch.scatter_add(self_re, dim, index, src_re) + ret_im = torch.scatter_add(self_im, dim, index, src_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.scatter_add_) +def scatter_add__impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + out_re = self_re.scatter_add_(dim, index, src_re) + out_im = self_im.scatter_add_(dim, index, src_im) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.index_put_) +def index_put__impl( + self: ComplexTensor, + indices: tuple[torch.Tensor, ...], + values: ComplexTensor, + accumulate: bool = False, +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + values_re, values_im = split_complex_arg(values) + + out_re = self_re.index_put_(indices, values_re, accumulate=accumulate) + out_im = self_im.index_put_(indices, values_im, accumulate=accumulate) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.tanh_backward) +def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor): + return out_grad * (1.0 - y * y).conj_physical() + + +@register_complex(aten.diagonal_backward) +def diagonal_backward( + grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any: + if not isinstance(dt, torch.dtype): + return dt + + return COMPLEX_TO_REAL[dt] + + +def register_to_impl(op: OpType): + """Register an op similar to `aten.to`, but may have different signatures.""" + + def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + x, y = split_complex_tensor(self) + try: + args = tuple(_dt_to_real(a) for a in args) + kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()} + except KeyError: + return op(x, *args, **kwargs) + + return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +to_impl = register_to_impl(aten.to) +_to_copy_impl = register_to_impl(aten._to_copy) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py new file mode 100644 index 0000000000000000000000000000000000000000..88532efe224bba013b221000a988b594ea01b2cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/common.py @@ -0,0 +1,317 @@ +from collections.abc import Callable +from typing import Any, overload, TypeAlias +from typing_extensions import TypeIs + +import torch +from torch import Tensor +from torch._decomp import get_decompositions +from torch._ops import OpOverload, OpOverloadPacket +from torch._refs import is_complex as _is_complex +from torch.types import Number +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from .._core import ComplexTensor + + +OpType: TypeAlias = OpOverloadPacket | OpOverload + +TableType: TypeAlias = dict[OpType, Callable] + +# Mapping from ops to implementations +COMPLEX_OPS_TABLE: TableType = {} + +COMPLEX_TO_REAL = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()} + +# Used to promote dtypes in `promote_real_cpu_tensors` +PROMOTE_TYPES = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.complex32: torch.complex64, +} + + +def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]: + r"""Returns True if the input is a ComplexTensor, else False + + Args: + a: any input + + Examples: + + >>> # xdoctest: +SKIP + >>> from torch.complex import ComplexTensor + >>> data = torch.zeros((3, 2), dtype=torch.complex64) + >>> ct = ComplexTensor.from_interleaved(data) + >>> is_complex_tensor(ct) + True + """ + return isinstance(obj, ComplexTensor) + + +@overload +def promote_tensors( + *tensors: ComplexTensor, +) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ... + + +@overload +def promote_tensors( + *tensors: Tensor, +) -> tuple[torch.dtype, tuple[Tensor, ...]]: ... + + +def promote_tensors( + *tensors: Tensor | ComplexTensor, +) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]: + """ + Promotes all tensors to a common dtype. + Additionally promotes CPU tensors to at least `float32`. + """ + tensor = next(t for t in tensors if isinstance(t, Tensor)) + out_dt = tensor.dtype + for t in tensors: + if isinstance(t, Tensor): + out_dt = torch.promote_types(out_dt, t.dtype) + + prom_dt = PROMOTE_TYPES.get(out_dt, out_dt) + return out_dt, tuple( + t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt) + for t in tensors + ) + + +def register_complex( + op: OpType, + func_impl: Callable | None = None, +): + """Decorator to register an implementation for some ops in some dispatch tables""" + + def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError(f"Attempted to register multiple functions for {op}") + COMPLEX_OPS_TABLE[op] = func + return func + + if func_impl is None: + return inner + + return inner(func_impl) + + +FORCE_TEST_LIST: list[OpType] = [] + + +def register_force_test(op: OpType, *args, **kwargs): + """Will attempt to test these ops even if they err on "normal" inputs""" + FORCE_TEST_LIST.append(op) + return register_complex(op, *args, **kwargs) + + +DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload] + + +def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None: + """ + Lookup an impl from the table. + + Try the particular overload first, then the overload packet. + + If nothing is found, try the decompositions with both. + """ + return COMPLEX_OPS_TABLE.get( + func, + COMPLEX_OPS_TABLE.get( + func.overloadpacket, + DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)), + ), + ) + + +def is_complex(x: Any, /) -> bool: + """Utility to detect if a given object is (known) to be complex.""" + return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex) + + +@overload +def split_complex_arg( + arg: Tensor | ComplexTensor, +) -> tuple[Tensor, Tensor]: ... + + +@overload +def split_complex_arg( + arg: complex | Number, +) -> tuple[Number, Number]: ... + + +def split_complex_arg( + arg: Tensor | ComplexTensor | complex | Number, +) -> tuple[Tensor, Tensor] | tuple[Number, Number]: + """ + Split a complex argument into a real/imaginary component. + + If real, use zero for the imaginary part. + """ + if isinstance(arg, ComplexTensor): + return split_complex_tensor(arg) + if isinstance(arg, Tensor): + if is_complex(arg): + return arg.real, arg.imag + return arg, torch.zeros_like(arg) + # TODO (hameerabbasi): Should there be a `torch.SymComplex`? + if isinstance(arg, complex): + return arg.real, arg.imag + if isinstance(arg, float | torch.SymFloat): + return arg, 0.0 + if isinstance(arg, int | torch.SymInt): + return arg, 0 + if isinstance(arg, bool | torch.SymBool): + return arg, False + raise TypeError(f"Expected tensor or number got, {type(arg)}") + + +def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]: + """Split a ComplexTensor into its real and imaginary parts.""" + return complex_tensor.re, complex_tensor.im + + +def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype: + """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is.""" + return COMPLEX_TO_REAL.get(dtype, dtype) + + +def _get_op_name(op: OpType) -> str: + """Get the op name from the op.""" + if isinstance(op, OpOverload): + op = op.overloadpacket + return str(op).split(".", 1)[1] + + +def _get_func_name(op: OpType) -> str: + """Get the name of the implementation function from the op.""" + return f"{_get_op_name(op)}_impl" + + +def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError): + msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`." + + def ordered_impl(*args, **kwargs): + raise exc_type(msg) + + func_name = _get_func_name(op) + ordered_impl.__name__ = func_name + ordered_impl.__qualname__ = func_name + + return register_force_test(op, ordered_impl) + + +def register_binary_nonlinear(op: OpType) -> Callable: + """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ...""" + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) + imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs) + return ComplexTensor(real.to(out_dt), imag.to(out_dt)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def register_simple(op: OpType): + """Register an op which can be applied independently to the real and complex parts to get the result.""" + + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: + x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError( + "Non-complex `dtype` specified, please write custom impl." + ) + + if dtype in COMPLEX_TO_REAL: + assert dtype is not None + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + + u_flat, u_spec = tree_flatten(u) + v_flat, v_spec = tree_flatten(v) + assert u_spec == v_spec + out_flat = [ + ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False) + ] + return tree_unflatten(out_flat, u_spec) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any: + """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is.""" + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any: + """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is.""" + if isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexTensorMode(TorchDispatchMode): + _compile: bool + + """ A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """ + + def __init__(self, _dispatch_key=None, *, _compile: bool = False): + """Initialize a ComplexTensorMode. + + Args: + _dispatch_key: passed on to TorchDispatchMode + _compile: Compile the op before the computation + """ + super().__init__(_dispatch_key) + self._compile = _compile + + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[type], + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + + # TODO (hameerabbasi): Test perf with `_compile` set to `True` + if self._compile: + func = torch.compile(func) # type: ignore[bad-assignment] + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py new file mode 100644 index 0000000000000000000000000000000000000000..9a237b32d99042a649632a432290919ea4db9c46 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/complex_tensor/_ops/prims.py @@ -0,0 +1,34 @@ +import torch + +from .._core import ComplexTensor +from .common import ( + complex_to_real_dtype, + register_complex, + register_force_test, + split_complex_tensor, +) + + +prims = torch.ops.prims +aten = torch.ops.aten + + +# TODO (hameerabbasi): Not being tested +@register_force_test(prims.convert_element_type) +def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: + dtype = complex_to_real_dtype(dtype) + u, v = split_complex_tensor(x) + u_out = prims.convert_element_type(u, dtype) + v_out = prims.convert_element_type(v, dtype) + + return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e202951c034e8fd2ab0039e69aba4d25e74a760 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..114e1ddf4d4cd6b7c22b1d4eb1e88f77e89b30a9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dca71fcf09b019f3e197576eb415ba4fd54fa28a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from .linear import Linear + + +__all__ = ["Linear"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71354aba101ef95434ce467d5c01c9014bb4d83f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9f85aead9bc2b7d9d95c96d95a99dfceda6eedf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2238eedf6f902174421a94702a4188fa463098 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,40 @@ +from typing import Optional, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfig + + +__all__ = ["Linear"] + + +class Linear(torch.ao.nn.qat.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for dynamic quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: Optional["QConfig"] = None, + device: int | str | torch.device | None = None, + dtype: str | None = None, + ) -> None: + super().__init__(in_features, out_features, bias, qconfig, device, dtype) + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): # type: ignore[arg-type] + raise ValueError( + "Dynamic QAT requires a memoryless observer." + + "This means a MovingAverage observer with averaging constant equal to 1" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..043c4371f66e11bbf09a15397a9446b7b564f556 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdca969d4bf2dead926422ca521e104736c74026 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a6057b9af76d7e578c4f8a591b7860cf206dc5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..105b75446ef729f76eb5a2aba5c05f3188cd2675 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d8545a7bdd4e460468348186a94453388aa6ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8e7dbadf28d20212d99e67bd7e764b6f3e5cc4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18dc18e7cb41258c2448afc66ec6e18daf384c6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3730f20830c3e4bd8706e9acae7eabbc233d553 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f1c9f902f6b9106a69af1d81a9f35288c763d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..890a5a12aacbd137ec765be46a39075d54162bcc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1663b4ac5145b862c6adcfab16330529f0d3797 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8297a66dbb8358dd9c9ebd2086cfaeb9471ef54 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d452359d41c36dc719e6df932b6fb018ca6a36b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py @@ -0,0 +1,30 @@ +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .executorch import get_executorch_backend_config +from .fbgemm import get_fbgemm_backend_config +from .native import get_native_backend_config, get_native_backend_config_dict +from .onednn import get_onednn_backend_config +from .qnnpack import get_qnnpack_backend_config +from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict + + +__all__ = [ + "get_fbgemm_backend_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_qnnpack_backend_config", + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", + "get_executorch_backend_config", + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", + "get_onednn_backend_config", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb322fd85d2c2b07a68f3b436e4f0536ac87e28 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -0,0 +1,782 @@ +# mypy: allow-untyped-defs +import copy +import operator +from collections import namedtuple +from collections.abc import Callable + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_convtranspose_bn, + fuse_linear_bn, +) + +from .backend_config import ( + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) + + +__all__: list[str] = [] + +# TODO: rename to be more explicit, e.g. qat_conv_relu +_ConvMetadata = namedtuple( + "_ConvMetadata", + [ + "root", + "transpose", + "bn", + "reference", + "transpose_reference", + "fused_conv_relu", + "fused_conv_bn", + "fused_conv_bn_relu", + "qat", + "relu_qat", + "bn_qat", + "bn_relu_qat", + "func", + "func_transpose", + ], +) +_Conv1dMetadata = _ConvMetadata( + nn.Conv1d, + nn.ConvTranspose1d, + nn.BatchNorm1d, + nnqr.Conv1d, + nnqr.ConvTranspose1d, + nni.ConvReLU1d, + nni.ConvBn1d, + nni.ConvBnReLU1d, + nnqat.Conv1d, + nniqat.ConvReLU1d, + nniqat.ConvBn1d, + nniqat.ConvBnReLU1d, + F.conv1d, + F.conv_transpose1d, +) +_Conv2dMetadata = _ConvMetadata( + nn.Conv2d, + nn.ConvTranspose2d, + nn.BatchNorm2d, + nnqr.Conv2d, + nnqr.ConvTranspose2d, + nni.ConvReLU2d, + nni.ConvBn2d, + nni.ConvBnReLU2d, + nnqat.Conv2d, + nniqat.ConvReLU2d, + nniqat.ConvBn2d, + nniqat.ConvBnReLU2d, + F.conv2d, + F.conv_transpose2d, +) +_Conv3dMetadata = _ConvMetadata( + nn.Conv3d, + nn.ConvTranspose3d, + nn.BatchNorm3d, + nnqr.Conv3d, + nnqr.ConvTranspose3d, + nni.ConvReLU3d, + nni.ConvBn3d, + nni.ConvBnReLU3d, + nnqat.Conv3d, + nniqat.ConvReLU3d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU3d, + F.conv3d, + F.conv_transpose3d, +) + +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: dict[Callable | str, DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + + +def _get_binary_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + binary_op_configs: list[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + operator.add, + torch.add, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, nn.ReLU), + (op_with_quantized_bop_scalar_variant, F.relu), + (op_with_quantized_bop_scalar_variant, torch.relu), + op_with_quantized_bop_scalar_variant, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) # noqa: E131 + ) + return binary_op_configs + + +def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: list[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU) + ) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearReLU) + ) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig((F.linear, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig((F.linear, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # (3) Linear + batchnorm + # ------------------------ + # 3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_linear_bn) + .set_fused_module(nni.LinearBn1d) + ) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearBn1d) + ) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + return linear_configs + + +def _get_conv_configs(dtype_configs): + """ + Return all configs related to conv modules and ops. + """ + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv relu fusion, conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # conv relu, functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig((convs.transpose, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_convtranspose_bn) + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference) + ) + + # 4.3 functional conv transpose + conv_configs.append( + BackendPatternConfig(convs.func_transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + return conv_configs + + +def _get_cat_config(dtype_configs: list[DTypeConfig]) -> BackendPatternConfig: + return ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + +def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + ln_configs = [] + ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + ln_configs.append( + BackendPatternConfig(torch.nn.functional.layer_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + return ln_configs + + +def _get_default_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + configs = [ + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in default_ops + ] + + configs.append( + BackendPatternConfig(torch.nn.functional.group_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + + configs.append( + BackendPatternConfig(torch.nn.functional.instance_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 3, "bias": 4}) + ) + return configs + + +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: list[DTypeConfig], + constraints: DTypeWithConstraints, +) -> list[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. + """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [ + dc.input_dtype_with_constraints, + dc.output_dtype_with_constraints, + ]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError( + f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}" + ) + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError( + f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}" + ) + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + + +def _get_fixed_qparams_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + fixed_qparams_op_configs = [] + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs( + dtype_configs, constraints + ) + fixed_qparams_op_configs.append( + BackendPatternConfig(fixed_qparam_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(new_dtype_configs) + ) + return fixed_qparams_op_configs + + +def _get_share_qparams_op_configs(dtype_configs): + """Get the operator config for the operators that works for both float and quantized input + if input is quantized, the output Tensor shares the same quantization parameter + with input. + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + torch.narrow, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + "contiguous", + "clamp", + "detach", + "detach_", + "mean", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "relu", + "relu_", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + return [_get_share_qprams_op_backend_config(op) for op in share_qparams_ops] + + +def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + """Get configs related to batchnorm.""" + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn: + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig((bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig((bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn) + ) + bn_configs.append( + BackendPatternConfig(bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_rnn_op_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [ + (nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM), + (nn.GRU, nnqr.GRU), + ]: + rnn_op_configs.append( + BackendPatternConfig(rnn_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(rnn_op) + .set_reference_quantized_module(ref_rnn_op) + ) + return rnn_op_configs + + +def _get_embedding_op_configs( + dtype_configs: list[DTypeConfig], +) -> list[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + return embedding_op_configs + + +def _get_tensor_info_op_configs(dtype_configs): + """ + These ops work on tensors of different dtypes but return non-tensors + containing information about the input tensor. + """ + + def _get_config(op): + return ( + BackendPatternConfig(op) + .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED) + .set_dtype_configs(dtype_configs) + ) + + return [_get_config(op) for op in ("shape", "size")] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e67b79c370207c228fd66d33fadad03a58ed2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + + +def get_linear_configs(): + linear_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + # TODO: need to fix the way we insert observers for this pattern + # should be solved in the new fusion API + # reason that this doesn't work: the pattern is a bit complicated and we don't + # have a way to specify which input of the pattern we would like to observe + # pattern: + # bias input weight + # \ | / + # \ | t + # \ | / + # addmm + # we want to observe "weight" as weight, but there is not way to convey this + # information with current pattern language + # + # right now: + # original: + # weight - t \ + # input - addmm + # observed (no hack): + # weight - t - observer \ + # input - observer - addmm + # target: + # weight - observer - t \ + # input - observer - addmm + + # def root_node_getter(node_pattern): + # addmm, bias, act, weight = node_pattern + # return addmm + + # linear_configs.append( + # BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default)) + # .set_observation_type(observation_type) # noqa: E131 + # .set_dtype_configs(dtype_configs) + # ._set_root_node_getter(root_node_getter)) + + linear_configs.append( + BackendPatternConfig(torch.ops.aten.addmm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 0}) + ) + # linear is decomposed to `t - mm` if bias is not present + linear_configs.append( + BackendPatternConfig(torch.ops.aten.mm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return linear_configs + + +def get_conv_configs(): + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + conv_configs.append( + BackendPatternConfig(torch.ops.aten.convolution.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + conv_configs.append( + BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + # TODO: remove when functionalization is supported in PT2 mode + conv_configs.append( + BackendPatternConfig( + (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return conv_configs + + +def get_pooling_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + def root_node_getter(node_pattern): + _getitem, maxpool, _index = node_pattern + return maxpool + + backend_pattern_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) + ) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_root_node_getter(root_node_getter) + ) + + return backend_pattern_configs + + +def get_relu_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + backend_pattern_configs.append( + BackendPatternConfig(torch.ops.aten.relu.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return backend_pattern_configs + + +def get_binary_op_configs(): + binary_op_configs: list[BackendPatternConfig] = [] + dtype_configs = [weighted_op_quint8_dtype_config] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default), + op_with_quantized_bop_scalar_variant, + # TODO: remove when functionalization is supported in pt2_mode + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default), + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + + return binary_op_configs + + +def get_qnnpack_pt2e_backend_config(): + return ( + BackendConfig("qnnpack_pytorch_2.0_export") + .set_backend_pattern_configs(get_linear_configs()) + .set_backend_pattern_configs(get_binary_op_configs()) + .set_backend_pattern_configs(get_conv_configs()) + .set_backend_pattern_configs(get_pooling_configs()) + .set_backend_pattern_configs(get_relu_configs()) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py new file mode 100644 index 0000000000000000000000000000000000000000..96a0b44a3afdf5a663319e67e1b422e17136d4d5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py @@ -0,0 +1,751 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", +] + + +# DTypeConfig dict keys +INPUT_DTYPE_DICT_KEY = "input_dtype" +OUTPUT_DTYPE_DICT_KEY = "output_dtype" +WEIGHT_DTYPE_DICT_KEY = "weight_dtype" +BIAS_DTYPE_DICT_KEY = "bias_dtype" +IS_DYNAMIC_DICT_KEY = "is_dynamic" + +# BackendConfig dict keys +NAME_DICT_KEY = "name" +CONFIGS_DICT_KEY = "configs" + +# BackendPatternConfig dict keys +PATTERN_DICT_KEY = "pattern" +PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format" +OBSERVATION_TYPE_DICT_KEY = "observation_type" +DTYPE_CONFIGS_DICT_KEY = "dtype_configs" +ROOT_MODULE_DICT_KEY = "root_module" +QAT_MODULE_DICT_KEY = "qat_module" +REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" +FUSED_MODULE_DICT_KEY = "fused_module" +FUSER_METHOD_DICT_KEY = "fuser_method" +ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" +EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" +NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" +INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" + + +# TODO: maybe rename this to something that's not related to observer +# e.g. QParamsType +class ObservationType(Enum): + """An enum that represents different ways of how an operator/operator pattern + should be observed + """ + + OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 + """this means input and output are observed with different observers, based + on qconfig.activation + example: conv, linear, softmax + """ + + OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 + """this means the output will use the same observer instance as input, based + on qconfig.activation + example: torch.cat, maxpool + """ + + INPUT_OUTPUT_NOT_OBSERVED = 2 + """this means the input and output are never observed + example: x.shape, x.size + """ + + +@dataclass +class DTypeWithConstraints: + """ + Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. + + The constraints currently supported are: + + * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper + bounds for the minimum and maximum quantized values respectively. If + the QConfig's `quant_min` and `quant_max` fall outside this range, + then the QConfig will be ignored. + + * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper + bounds for the minimum and maximum scale values respectively. If the + QConfig's minimum scale value (currently exposed as `eps`) falls below + the lower bound, then the QConfig will be ignored. Note that the upper + bound is currently not enforced. + + * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements + for scale and zero point, to be used for operators with fixed quantization + parameters such as sigmoid and tanh. If the observer specified in the QConfig + is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if + the quantization parameters don't match, then the QConfig will be ignored. + """ + + dtype: torch.dtype | None = None + quant_min_lower_bound: int | float | None = None + quant_max_upper_bound: int | float | None = None + scale_min_lower_bound: int | float | None = None + scale_max_upper_bound: int | float | None = None + scale_exact_match: float | None = None + zero_point_exact_match: int | None = None + + +@dataclass +class DTypeConfig: + """ + Config object that specifies the supported data types passed as arguments to + quantize ops in the reference model spec, for input and output activations, + weights, and biases. + + For example, consider the following reference model: + + quant1 - [dequant1 - fp32_linear - quant2] - dequant2 + + The pattern in the square brackets refers to the reference pattern of + statically quantized linear. Setting the input dtype as `torch.quint8` + in the DTypeConfig means we pass in `torch.quint8` as the dtype argument + to the first quantize op (quant1). Similarly, setting the output dtype as + `torch.quint8` means we pass in `torch.quint8` as the dtype argument to + the second quantize op (quant2). + + Note that the dtype here does not refer to the interface dtypes of the + op. For example, the "input dtype" here is not the dtype of the input + tensor passed to the quantized linear op. Though it can still be the + same as the interface dtype, this is not always the case, e.g. the + interface dtype is fp32 in dynamic quantization but the "input dtype" + specified in the DTypeConfig would still be quint8. The semantics of + dtypes here are the same as the semantics of the dtypes specified in + the observers. + + These dtypes are matched against the ones specified in the user's + QConfig. If there is a match, and the QConfig satisfies the constraints + specified in the DTypeConfig (if any), then we will quantize the given + pattern using this DTypeConfig. Otherwise, the QConfig is ignored and + the pattern will not be quantized. + + Example usage:: + + >>> # xdoctest: +SKIP(failing) + >>> dtype_config1 = DTypeConfig( + ... input_dtype=torch.quint8, + ... output_dtype=torch.quint8, + ... weight_dtype=torch.qint8, + ... bias_dtype=torch.float) + + >>> dtype_config2 = DTypeConfig( + ... input_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... output_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... weight_dtype=DTypeWithConstraints( + ... dtype=torch.qint8, + ... quant_min_lower_bound=-128, + ... quant_max_upper_bound=127, + ... ), + ... bias_dtype=torch.float) + + >>> dtype_config1.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype_with_constraints + DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \ +scale_min_lower_bound=None, scale_max_upper_bound=None) + """ + + input_dtype_with_constraints: DTypeWithConstraints + output_dtype_with_constraints: DTypeWithConstraints + weight_dtype_with_constraints: DTypeWithConstraints + bias_dtype: torch.dtype | None + is_dynamic: bool | None + + def __init__( + self, + input_dtype: torch.dtype | DTypeWithConstraints | None = None, + output_dtype: torch.dtype | DTypeWithConstraints | None = None, + weight_dtype: torch.dtype | DTypeWithConstraints | None = None, + bias_dtype: torch.dtype | None = None, + is_dynamic: bool | None = None, + ): + if isinstance(input_dtype, DTypeWithConstraints): + self.input_dtype_with_constraints = input_dtype + else: + self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype) + + if isinstance(output_dtype, DTypeWithConstraints): + self.output_dtype_with_constraints = output_dtype + else: + self.output_dtype_with_constraints = DTypeWithConstraints( + dtype=output_dtype + ) + + if isinstance(weight_dtype, DTypeWithConstraints): + self.weight_dtype_with_constraints = weight_dtype + else: + self.weight_dtype_with_constraints = DTypeWithConstraints( + dtype=weight_dtype + ) + + self.bias_dtype = bias_dtype + self.is_dynamic = is_dynamic + + @property + def input_dtype(self) -> torch.dtype | None: + return self.input_dtype_with_constraints.dtype + + @property + def output_dtype(self) -> torch.dtype | None: + return self.output_dtype_with_constraints.dtype + + @property + def weight_dtype(self) -> torch.dtype | None: + return self.weight_dtype_with_constraints.dtype + + @classmethod + def from_dict(cls, dtype_config_dict: dict[str, Any]) -> DTypeConfig: + """ + Create a ``DTypeConfig`` from a dictionary with the following items (all optional): + "input_dtype": torch.dtype or ``DTypeWithConstraints`` + "output_dtype": torch.dtype or ``DTypeWithConstraints`` + "weight_dtype": torch.dtype or ``DTypeWithConstraints`` + "bias_type": torch.dtype + "is_dynamic": bool + """ + input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY) + if input_dtype is not None and not isinstance( + input_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected input_dtype to be a torch.dtype or DTypeWithConstraints" + ) + output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY) + if output_dtype is not None and not isinstance( + output_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected output_dtype to be a torch.dtype or DTypeWithConstraints" + ) + weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY) + if weight_dtype is not None and not isinstance( + weight_dtype, (torch.dtype, DTypeWithConstraints) + ): + raise ValueError( + "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints" + ) + bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY) + is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY) + return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``DTypeConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. + """ + dtype_config_dict: dict[str, Any] = {} + if self.input_dtype is not None: + dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints + if self.output_dtype is not None: + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = ( + self.output_dtype_with_constraints + ) + if self.weight_dtype is not None: + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = ( + self.weight_dtype_with_constraints + ) + if self.bias_dtype is not None: + dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype + if self.is_dynamic is not None: + dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic + return dtype_config_dict + + +class BackendConfig: + # TODO: refer to NativeBackendConfig once that is implemented + """Config that defines the set of patterns that can be quantized on a given backend, and how reference + quantized models can be produced from these patterns. + + A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph + of the above. Each pattern supported on the target backend can be individually configured through + :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: + + (1) The supported input/output activation, weight, and bias data types + + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + + (3) (Optionally) Fusion, QAT, and reference module mappings. + + The format of the patterns is described in: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md + + Example usage:: + + import torch + from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, + ) + + weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + + def fuse_conv2d_relu(is_qat, conv, relu): + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + + # For quantizing Linear + linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + + # For fusing Conv2d + ReLU into ConvReLU2d + conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + + # For quantizing ConvReLU2d + fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) + + backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) + + """ + + def __init__(self, name: str = ""): + self.name = name + # Store all BackendPatternConfigs in a map to handle duplicates + # Note: the key in this map uses the complex reversed tuple format. + # This is intended only for internal use; users who wish to access + # the original patterns should go through `self.configs` instead. + self._pattern_complex_format_to_config: dict[Pattern, BackendPatternConfig] = {} + + def __repr__(self): + return f"BackendConfig({self.__dict__})" + + def set_name(self, name: str) -> BackendConfig: + """ + Set the name of the target backend. + """ + self.name = name + return self + + def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: + """ + Set the config for an pattern that can be run on the target backend. + This overrides any existing config for the given pattern. + """ + # Avoid circular dependencies + pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format( + config + ) # type: ignore[attr-defined] + self._pattern_complex_format_to_config[pattern_complex_format] = config + return self + + def set_backend_pattern_configs( + self, configs: list[BackendPatternConfig] + ) -> BackendConfig: + """ + Set the configs for patterns that can be run on the target backend. + This overrides any existing config for a given pattern if it was previously registered already. + """ + for conf in configs: + self.set_backend_pattern_config(conf) + return self + + @property + def configs(self) -> list[BackendPatternConfig]: + """ + Return a copy of the list of configs set in this `BackendConfig`. + """ + return list(self._pattern_complex_format_to_config.values()) + + @classmethod + def from_dict(cls, backend_config_dict: dict[str, Any]) -> BackendConfig: + """ + Create a ``BackendConfig`` from a dictionary with the following items: + + "name": the name of the target backend + + "configs": a list of dictionaries that each represents a `BackendPatternConfig` + + """ + conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) + for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): + if isinstance(d, BackendPatternConfig): + conf.set_backend_pattern_config(d) + elif isinstance(d, dict): + conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) + else: + raise ValueError( + f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary" + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. + """ + return { + NAME_DICT_KEY: self.name, + CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs], + } + + +class BackendPatternConfig: + """ + Config object that specifies quantization behavior for a given operator pattern. + For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. + """ + + def __init__(self, pattern: Pattern | None = None): + self.pattern: Pattern | None = pattern + self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + self.dtype_configs: list[DTypeConfig] = [] + self.root_module: type[torch.nn.Module] | None = None + self.qat_module: type[torch.nn.Module] | None = None + self.reference_quantized_module: type[torch.nn.Module] | None = None + self.fused_module: type[torch.nn.Module] | None = None + self.fuser_method: Callable | None = None + + # Temporary/internal configs + self._root_node_getter: Callable | None = None + self._extra_inputs_getter: Callable | None = None + self._num_tensor_args_to_observation_type: dict[int, ObservationType] = {} + self._input_type_to_index: dict[str, int] = {} + self._pattern_complex_format: Pattern | None = None + + def __repr__(self): + dict_nonempty = { + k: v + for k, v in self.__dict__.items() + if ( + (not isinstance(v, (list, dict)) and v is not None) + or (isinstance(v, (list, dict)) and len(v) > 0) + ) + } + return f"BackendPatternConfig({dict_nonempty})" + + def set_pattern(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure. + + The pattern can be a float module, functional operator, pytorch operator, or a tuple + combination of the above. Tuple patterns are treated as sequential patterns, and + currently only tuples of 2 or 3 elements are supported. + """ + if self._pattern_complex_format is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self.pattern = pattern + return self + + def set_observation_type( + self, observation_type: ObservationType + ) -> BackendPatternConfig: + """ + Set how observers should be inserted in the graph for this pattern. + + Observation type here refers to how observers (or quant-dequant ops) will be placed + in the graph. This is used to produce the desired reference patterns understood by + the backend. Weighted ops such as linear and conv require different observers + (or quantization parameters passed to quantize ops in the reference model) for the + input and the output. + + There are two observation types: + + `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance + will be different from the input. This is the most common observation type. + + `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the + same as the input. This is useful for operators like `cat`. + + Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs + with observers (and fake quantizes) attached instead of observers themselves. + """ + self.observation_type = observation_type + return self + + def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: + """ + Add a set of supported data types passed as arguments to quantize ops in the + reference model spec. + """ + self.dtype_configs.append(dtype_config) + return self + + def set_dtype_configs( + self, dtype_configs: list[DTypeConfig] + ) -> BackendPatternConfig: + """ + Set the supported data types passed as arguments to quantize ops in the + reference model spec, overriding all previously registered data types. + """ + self.dtype_configs = dtype_configs + return self + + def set_root_module( + self, root_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the root for this pattern. + + When we construct the reference quantized model during the convert phase, + the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU) + will be swapped to the corresponding reference quantized modules (e.g. + torch.ao.nn.reference.quantized.Linear). This allows custom backends to + specify custom reference quantized module implementations to match the + numerics of their lowered operators. Since this is a one-to-one mapping, + both the root module and the reference quantized module must be specified + in the same BackendPatternConfig in order for the conversion to take place. + """ + self.root_module = root_module + return self + + def set_qat_module(self, qat_module: type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the QAT implementation for this pattern. + """ + self.qat_module = qat_module + return self + + def set_reference_quantized_module( + self, reference_quantized_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the reference quantized implementation for + this pattern's root module. + + For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`. + """ + self.reference_quantized_module = reference_quantized_module + return self + + def set_fused_module( + self, fused_module: type[torch.nn.Module] + ) -> BackendPatternConfig: + """ + Set the module that represents the fused implementation for this pattern. + """ + self.fused_module = fused_module + return self + + def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: + """ + Set the function that specifies how to fuse this BackendPatternConfig's pattern. + + The first argument of this function should be `is_qat`, and the rest of the arguments + should be the items in the tuple pattern. The return value of this function should be + the resulting fused module. + + For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be: + + def fuse_linear_relu(is_qat, linear, relu): + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) + + For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6. + """ + self.fuser_method = fuser_method + return self + + def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: + self._root_node_getter = root_node_getter + return self + + def _set_extra_inputs_getter( + self, extra_inputs_getter: Callable + ) -> BackendPatternConfig: + self._extra_inputs_getter = extra_inputs_getter + return self + + def _set_num_tensor_args_to_observation_type( + self, num_tensor_args_to_observation_type: dict[int, ObservationType] + ) -> BackendPatternConfig: + self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type + return self + + def _set_input_type_to_index( + self, input_type_to_index: dict[str, int] + ) -> BackendPatternConfig: + self._input_type_to_index = input_type_to_index + return self + + def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure, using the reversed nested tuple format. + + See the BackendConfig README for more detail: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification + """ + if self.pattern is not None: + raise ValueError( + "Only one of 'pattern' or 'pattern_complex_format' can be set" + ) + self._pattern_complex_format = pattern + return self + + @classmethod + def from_dict( + cls, backend_pattern_config_dict: dict[str, Any] + ) -> BackendPatternConfig: + """ + Create a ``BackendPatternConfig`` from a dictionary with the following items: + + "pattern": the pattern being configured + "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s + "root_module": a :class:`torch.nn.Module` that represents the root for this pattern + "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern + "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized + implementation for this pattern's root module. + "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern + "fuser_method": a function that specifies how to fuse the pattern for this pattern + "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated) + + """ + + def _get_dtype_config(obj: Any) -> DTypeConfig: + """ + Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. + """ + if isinstance(obj, DTypeConfig): + return obj + if isinstance(obj, dict): + return DTypeConfig.from_dict(obj) + raise ValueError( + f"Expected a list of DTypeConfigs in " + f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'" + ) + + conf = cls() + if PATTERN_DICT_KEY in backend_pattern_config_dict: + conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY]) + if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: + conf.set_observation_type( + backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY] + ) + for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): + conf.add_dtype_config(_get_dtype_config(d)) + conf.set_root_module( + backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY)) # type: ignore[arg-type] + conf.set_reference_quantized_module( + backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_fused_module( + backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY) # type: ignore[arg-type] + ) + conf.set_fuser_method( + backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_root_node_getter( + backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_extra_inputs_getter( + backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY) # type: ignore[arg-type] + ) + conf._set_num_tensor_args_to_observation_type( + backend_pattern_config_dict.get( + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {} + ) + ) + conf._set_input_type_to_index( + backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}) + ) + if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict: + conf._set_pattern_complex_format( + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``BackendPatternConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. + """ + backend_pattern_config_dict: dict[str, Any] = { + OBSERVATION_TYPE_DICT_KEY: self.observation_type, + DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], + } + if self.pattern is not None: + backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern + if self.root_module is not None: + backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module + if self.qat_module is not None: + backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module + if self.reference_quantized_module is not None: + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = ( + self.reference_quantized_module + ) + if self.fused_module is not None: + backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module + if self.fuser_method is not None: + backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method + if self._root_node_getter is not None: + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = ( + self._root_node_getter + ) + if self._extra_inputs_getter is not None: + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = ( + self._extra_inputs_getter + ) + if len(self._num_tensor_args_to_observation_type) > 0: + backend_pattern_config_dict[ + NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY + ] = self._num_tensor_args_to_observation_type + if len(self._input_type_to_index) > 0: + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = ( + self._input_type_to_index + ) + if self._pattern_complex_format is not None: + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = ( + self._pattern_complex_format + ) + return backend_pattern_config_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9b16492821b73dba1ff3ce6e2617d844d94229 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py @@ -0,0 +1,498 @@ +# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime +# not a specific backend + +import operator + +import torch +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, +) + +from ._common_operator_config_utils import _Conv2dMetadata +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .qnnpack import ( + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_qint8_symmetric_dtype_config, +) + + +__all__ = [ + "get_executorch_backend_config", +] + + +# =================== +# | DTYPE CONFIGS | +# =================== + +executorch_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +executorch_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +executorch_default_dynamic_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +executorch_default_dynamic_qint8_dtype_config = DTypeConfig( + input_dtype=executorch_act_qint8_scale_min_2_neg_12, + output_dtype=torch.float, + weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + + +# ============================= +# | BACKEND PATTERN CONFIGS | +# ============================= + + +def _get_linear_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + executorch_default_dynamic_quint8_dtype_config, + executorch_default_dynamic_qint8_dtype_config, + executorch_default_dynamic_float16_dtype_config, + ] + linear_configs: list[BackendPatternConfig] = [] + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return linear_configs + + +def _get_conv_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to conv modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + conv_configs = [] + for convs in [_Conv2dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------------------------- + # conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # fused conv relu module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # conv + batchnorm (+ relu) + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + return conv_configs + + +def _get_binary_ops_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to binary ops. + """ + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + binary_op_configs: list[BackendPatternConfig] = [] + for op in [ + operator.add, + torch.add, + operator.sub, + torch.sub, + operator.mul, + torch.mul, + ]: + bop_patterns = [ + (op, torch.nn.ReLU), + (op, torch.nn.functional.relu), + (op, torch.relu), + op, + ] + binary_op_configs.extend( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + for bop_pattern in bop_patterns + ) + return binary_op_configs + + +def _get_share_qparams_ops_configs() -> list[BackendPatternConfig]: + """ + Return the operator configs for the operators that works for both float and quantized + input if input is quantized, the output Tensor shares the same quantization parameter + with input. + + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + share_qparams_ops = [ + torch.nn.Flatten, + F.adaptive_avg_pool2d, + F.elu, + F.hardtanh, + F.max_pool2d, + F.pad, + F.relu, + F.relu6, + F.leaky_relu, + F.leaky_relu_, + torch.nn.AdaptiveAvgPool2d, + torch.nn.ConstantPad2d, + torch.nn.ELU, + torch.nn.MaxPool2d, + torch.nn.ReLU6, + torch.nn.Hardtanh, + torch.nn.LeakyReLU, + torch.clamp, + torch.flatten, + torch.mean, + torch.permute, + torch.permute_copy, + torch.squeeze, + "clamp", + "mean", + "permute", + "reshape", + "relu", + "relu_", + "squeeze", + "squeeze_", + "leaky_relu", + ] + share_qparams_op_configs: list[BackendPatternConfig] = [ + BackendPatternConfig(op) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + for op in share_qparams_ops + ] + return share_qparams_op_configs + + +def _get_bn_configs() -> list[BackendPatternConfig]: + """ + Return all configs related to batchnorm. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + bn_configs = [] + bn_configs.append( + BackendPatternConfig(nn.BatchNorm2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_cat_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concatenate) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + return cat_configs + + +def _get_embedding_op_configs() -> list[BackendPatternConfig]: + dtype_configs = [ + executorch_weight_only_quint8_dtype_config, + ] + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for functional embedding + embedding_op_configs.append( + BackendPatternConfig(torch.nn.functional.embedding) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return embedding_op_configs + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_executorch_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack. + """ + return ( + BackendConfig("executorch") + .set_backend_pattern_configs(_get_linear_configs()) + .set_backend_pattern_configs(_get_conv_configs()) + .set_backend_pattern_configs(_get_binary_ops_configs()) + .set_backend_pattern_configs(_get_share_qparams_ops_configs()) + .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_cat_configs()) + .set_backend_pattern_configs(_get_embedding_op_configs()) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5d665f4fd030aba47c98ee692f0d9e7eca41cbc6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py @@ -0,0 +1,129 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_fbgemm_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py +# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max, +# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized +# values to within [0, 127]. + +fbgemm_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +fbgemm_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +fbgemm_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +fbgemm_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_fbgemm_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native FBGEMM backend. + """ + conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + fbgemm_weighted_op_quint8_dtype_config, + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + fbgemm_weight_only_quint8_dtype_config, + fbgemm_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("fbgemm") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py new file mode 100644 index 0000000000000000000000000000000000000000..a98d1a9a3d41b43b1c0ce55a2471d3342af71a55 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_test_only_legacy_native_backend_config", + "default_op_quint8_dtype_config", + "default_op_fp16_dtype_config", + "default_dynamic_int8_dtype_config", + "default_dynamic_float16_dtype_config", + "input_output_only_quint8_dtype_config", + "weight_only_quint8_dtype_config", + "weight_only_quint4x2_dtype_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_test_only_legacy_native_backend_config_dict", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights +input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_test_only_legacy_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops. + """ + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + default_op_fp16_dtype_config, + ] + binary_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + share_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + tensor_info_op_dtype_configs = [ + default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("_native_and_fp16") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + """ + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [default_op_quint8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return ( + BackendConfig("native") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +def get_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form. + """ + return get_native_backend_config().to_dict() + + +def get_test_only_legacy_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional + fp16 ops in dictionary form. + """ + return get_test_only_legacy_native_backend_config().to_dict() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc7a2cf4c669742583fb1fe23d9948e04dbecc1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py @@ -0,0 +1,641 @@ +# mypy: allow-untyped-defs +import itertools +import operator + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2 +from torch.ao.quantization.utils import MatchAllNode + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +# =================== +# | DTYPE CONFIGS | +# =================== + +onednn_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +onednn_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +onednn_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +onednn_weight_only_qint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.qint8, +) + +onednn_input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +# =================== +# | FUSER METHODS | +# =================== + + +def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): + r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + leaky_relu: LeakyReLU instance that needs to be fused with the linear layer + Examples:: + >>> # xdoctest: +SKIP(failing) + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> lr = nn.LeakyReLU(0.01) + >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) + """ + if linear.training != bn.training or bn.training != leaky_relu.training: + raise AssertionError( + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + ) + + if is_qat: + raise NotImplementedError( + f"Cannot fuse train modules: {(linear, bn, leaky_relu)}" + ) + else: + map_to_fused_module_eval = { + nn.Linear: nni.LinearLeakyReLU, + } + fused_module = map_to_fused_module_eval.get(type(linear), None) + if fused_module is not None: + fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + fm = fused_module(fused_linear, leaky_relu) + return fm + else: + raise NotImplementedError( + f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}" + ) + + +# ====================== +# | CONFIGS FOR CONV | +# ====================== +observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + +conv_dtype_configs = [onednn_weighted_op_int8_dtype_config] +conv_configs = _get_conv_configs(conv_dtype_configs) + +# (1) Conv2d + Add + +# conv2d Y +# \ / +# add + +# include: +# conv2d conv2d +# \ / +# add + + +def _fuse_conv_add_left(is_qat, add, conv, _): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_left(pattern): + _, conv, _ = pattern + return conv + + +def _conv_add_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _conv, extra_input = pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add + + +def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_left(add_pattern): + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_left(add_pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_left) + ._set_root_node_getter(_conv_bn_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_left) + ._set_root_node_getter(_conv_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d) + ) + +# Y conv2d +# \ / +# add + + +def _fuse_conv_add_right(is_qat, add, _, conv): + return nni.ConvAdd2d(conv, add) + + +def _conv_add_root_node_getter_right(pattern): + _add, _, conv = pattern + return conv + + +def _conv_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _conv = pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add + + +def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + + +def _conv_bn_add_root_node_getter_right(pattern): + _add, _, bn_conv = pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, _bn_conv = pattern + return [extra_input] + + +conv_add_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_right) + ._set_root_node_getter(_conv_bn_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_right) + ._set_root_node_getter(_conv_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d) + ) + +conv_configs.append( + BackendPatternConfig(nni.ConvAdd2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# (2) Conv2d + Add + Relu + +# conv2d Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_left(is_qat, relu, add_pattern): + add, conv, _ = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, conv, _ = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _conv, extra_input = add_pattern + return [extra_input] + + +# conv2d +# \ +# bn Y +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern): + add, bn_conv, _ = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_left(pattern): + _relu, add_pattern = pattern + _, bn_conv, _ = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_left(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, _bn_conv, extra_input = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_left) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_left) + ._set_root_node_getter(_conv_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d) + ) + +# Y conv2d +# \ / +# add +# \ +# relu + + +def _fuse_conv_add_relu_right(is_qat, relu, add_pattern): + add, _, conv = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + + +def _conv_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _extra_input, conv = add_pattern + return conv + + +def _conv_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _conv = add_pattern + return [extra_input] + + +# conv2d +# / +# Y bn +# \ / +# add +# \ +# relu + + +def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern): + add, _, bn_conv = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + + +def _conv_bn_add_relu_root_node_getter_right(pattern): + _relu, add_pattern = pattern + _, _, bn_conv = add_pattern + _bn, conv = bn_conv + return conv + + +def _conv_bn_add_relu_extra_inputs_getter_right(pattern): + """get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _relu, add_pattern = pattern + _, extra_input, _bn_conv = add_pattern + return [extra_input] + + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format( + (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) + ) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_right) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_right) + ._set_root_node_getter(_conv_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d) + ) + +conv_configs.append( + BackendPatternConfig(nni.ConvAddReLU2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d) +) + +# ======================== +# | CONFIGS FOR LINEAR | +# ======================== + +linear_dtype_configs = [ + onednn_weighted_op_int8_dtype_config, + onednn_dynamic_int8_dtype_config, +] +linear_configs = _get_linear_configs(linear_dtype_configs) + + +def _add_eltwise_fusion_configs( + configs, + root_module, + root_op, + post_module, + post_op, + dtype_configs, + fuser_method, + fused_module, + observation_type, + ref_quant_module, +): + # 1 base module + op module fusion config + configs.append( + BackendPatternConfig((root_module, post_module)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + # base module + functional post op + configs.append( + BackendPatternConfig((root_module, post_op)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module) + ) + + # 2 fused module configs + configs.append( + BackendPatternConfig(fused_module) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(root_module) + .set_reference_quantized_module(ref_quant_module) + ) + + # 3 functional base op + post op configs + configs.append( + BackendPatternConfig((root_op, post_module)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + configs.append( + BackendPatternConfig((root_op, post_op)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + + +# Configs for linear + leaky_relu fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.LeakyReLU, + F.leaky_relu, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearLeakyReLU), + nni.LinearLeakyReLU, + observation_type, + nnqr.Linear, +) + +# Configs for linear module + batchnorm + leaky_relu +linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU)) + .set_dtype_configs(linear_dtype_configs) # noqa: E131 + .set_fuser_method(_fuse_linear_bn_leaky_relu) + .set_fused_module(nni.LinearLeakyReLU) +) + +# Configs for linear + tanh fusion +_add_eltwise_fusion_configs( + linear_configs, + nn.Linear, + F.linear, + nn.Tanh, + torch.tanh, + linear_dtype_configs, + _sequential_wrapper2(nni.LinearTanh), + nni.LinearTanh, + observation_type, + nnqr.Linear, +) + +# =========================== +# | CONFIGS FOR OTHER OPS | +# =========================== + +binary_op_dtype_configs = [onednn_op_quint8_dtype_config] +default_op_dtype_configs = [onednn_op_quint8_dtype_config] +fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config] +embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config] +layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config] + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_onednn_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native ONEDNN backend. + """ + return ( + BackendConfig("onednn") + .set_backend_pattern_configs(conv_configs) + .set_backend_pattern_configs(linear_configs) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) + + +__all__ = [ + "get_onednn_backend_config", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..841bac512a6549f39f757b9531591f1e47e72a83 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py @@ -0,0 +1,171 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints + + +__all__ = [ + "get_qnnpack_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +qnnpack_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +qnnpack_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +qnnpack_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +qnnpack_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + +# xnnpack compatible dtype configs + +# We restrict scale values to be 2 ** -12 to ensure the +# requantization scale never falls below the xnnpack lower +# threshold. Additionally, for qint8 weight, we restrict +# the quantization values to [-127, +127], excluding -128. +# For more detail, refer to the description of +# `default_symmetric_qnnpack_qconfig`. + +# TODO: add additional restriction on qscheme to ensure it +# is either per_tensor_symmetric or per_channel_symmetric + +qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, +) + +qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_qnnpack_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native QNNPACK backend. + """ + conv_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + ] + linear_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + default_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + fixed_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + qnnpack_weight_only_quint8_dtype_config, + qnnpack_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("qnnpack") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..d0490e2071f4f2df59b4bb6eb2a1d7885b4aa036 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + + +__all__ = [ + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", +] + + +def get_tensorrt_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for the TensorRT backend. + NOTE: Current api will change in the future, it's just to unblock experimentation for + new backends, please don't use it right now. + TODO: add a README when it's more stable + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + ) + + addmm_config = ( + BackendPatternConfig(torch.addmm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) + .add_dtype_config(weighted_op_qint8_dtype_config) + ._set_input_type_to_index( + { + "bias": 0, + "input": 1, + "weight": 2, + } + ) + ) + cat_config = ( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .add_dtype_config(non_weighted_op_qint8_dtype_config) + ) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + tensor_info_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return ( + BackendConfig("tensorrt") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_config(addmm_config) + .set_backend_pattern_config(cat_config) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + ) + + +def get_tensorrt_backend_config_dict(): + """ + Return the `BackendConfig` for the TensorRT backend in dictionary form. + """ + return get_tensorrt_backend_config().to_dict() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d486a061129324a311cdf74ebb58a51bf2dd9d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3 +from torch.ao.quantization.utils import Pattern + +from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig + + +__all__ = [ + "get_pattern_to_dtype_configs", + "get_qat_module_classes", + "get_fused_module_classes", + "get_pattern_to_input_type_to_index", + "get_root_module_to_quantized_reference_module", + "get_fuser_method_mapping", + "get_module_to_qat_module", + "get_fusion_pattern_to_root_node_getter", + "get_fusion_pattern_to_extra_inputs_getter", + "remove_boolean_dispatch_from_name", + "pattern_to_human_readable", + "entry_to_pretty_str", +] + + +def get_pattern_to_dtype_configs( + backend_config: BackendConfig, +) -> dict[Pattern, list[DTypeConfig]]: + pattern_to_dtype_configs: dict[Pattern, list[DTypeConfig]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_dtype_configs[pattern] = config.dtype_configs + return pattern_to_dtype_configs + + +def get_qat_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + qat_module_classes = [ + config.qat_module + for config in backend_config.configs + if config.qat_module is not None + ] + return tuple(set(qat_module_classes)) + + +def get_fused_module_classes(backend_config: BackendConfig) -> tuple[type, ...]: + fused_module_classes = [ + config.fused_module + for config in backend_config.configs + if config.fused_module is not None + ] + return tuple(set(fused_module_classes)) + + +def get_pattern_to_input_type_to_index( + backend_config: BackendConfig, +) -> dict[Pattern, dict[str, int]]: + pattern_to_input_type_to_index: dict[Pattern, dict[str, int]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_input_type_to_index[pattern] = config._input_type_to_index + return pattern_to_input_type_to_index + + +def get_root_module_to_quantized_reference_module( + backend_config: BackendConfig, +) -> dict[type[torch.nn.Module], type[torch.nn.Module]]: + mapping: dict[type[torch.nn.Module], type[torch.nn.Module]] = {} + for config in backend_config.configs: + if ( + config.root_module is not None + and config.reference_quantized_module is not None + ): + mapping[config.root_module] = config.reference_quantized_module + return mapping + + +def get_fuser_method_mapping( + backend_config: BackendConfig, +) -> dict[Pattern, nn.Sequential | Callable]: + fuser_method_mapping: dict[Pattern, nn.Sequential | Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # Note: both the fuser method and the pattern are specified in forward order in the + # BackendConfig, but the internal pattern matching code uses the reversed nested tuple + # format, so we need to convert both to the internal format + fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config) + fuser_method_mapping[pattern] = fuser_method + return fuser_method_mapping + + +def get_module_to_qat_module( + backend_config: BackendConfig, +) -> dict[Pattern, type[torch.nn.Module]]: + module_to_qat_module: dict[Pattern, type[torch.nn.Module]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.qat_module is not None: + module_to_qat_module[pattern] = config.qat_module + return module_to_qat_module + + +def get_fusion_pattern_to_root_node_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns the root node + from the fusion pattern, e.g. the most common one is: + def get_root_node(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + This can work for all patterns whose root node is the "last node" in the pattern, + e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d)) + """ + root_node_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._root_node_getter is not None: + root_node_getter_mapping[pattern] = config._root_node_getter + return root_node_getter_mapping + + +def get_fusion_pattern_to_extra_inputs_getter( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + """Get a map from fusion pattern to a function that returns extra input nodes + from the fusion pattern, in the order required by the root node. This is optional, + if not specified, we will not copy over any extra inputs for the root node. + Example: + # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) + # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra + # argument to the fused module, we can unpack the pattern and return the node at + # MatchAllNode here + # we can implement extra_inputs_getter as follows: + def extra_inputs_getter(pattern) -> List[Any]: + add, extra_input, conv_pattern = pattern + return [extra_input] + """ + extra_inputs_getter_mapping: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._extra_inputs_getter is not None: + extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter + return extra_inputs_getter_mapping + + +def remove_boolean_dispatch_from_name(p) -> Any: + """ + Some ops have a default string representation such as + '.fn at 0x7ff1106bf280>', + this function replaces them with the hardcoded function names. + """ + if p is F.fractional_max_pool2d: + return "torch.nn.functional.fractional_max_pool2d" + elif p is F.fractional_max_pool3d: + return "torch.nn.functional.fractional_max_pool3d" + elif p is F.max_pool1d: + return "torch.nn.functional.max_pool1d" + elif p is F.max_pool2d: + return "torch.nn.functional.max_pool2d" + elif p is F.max_pool3d: + return "torch.nn.functional.max_pool3d" + elif p is F.adaptive_max_pool1d: + return "torch.nn.functional.adaptive_max_pool1d" + elif p is F.adaptive_max_pool2d: + return "torch.nn.functional.adaptive_max_pool2d" + elif p is F.adaptive_max_pool3d: + return "torch.nn.functional.adaptive_max_pool3d" + if "boolean_dispatch" in str(p): + raise AssertionError( + f"{p} does not have a human readable representation in " + + "quantization documentation" + ) + return p + + +def pattern_to_human_readable(p) -> Any: + if isinstance(p, tuple): + # nested patterns, recurse + return tuple(pattern_to_human_readable(inner_p) for inner_p in p) + elif isinstance(p, str): + # method names are already human readable + return p + else: + p = remove_boolean_dispatch_from_name(p) + return p + + +# TODO(future PR): move backend_config_dict to use dataclass and move this logic to +# the corresponding __str__ function +def entry_to_pretty_str(entry) -> str: + """ + Given a backend_config_dict entry, returns a string with the human readable + representation of it. + """ + s = "{\n" + + # always output the pattern first + if "pattern" in entry: + pattern_str = pattern_to_human_readable(entry["pattern"]) + + s += f" 'pattern': {pattern_str},\n" + + # custom output for dtype_configs to make it look nice + if "dtype_configs" in entry: + s += " 'dtype_configs': [\n" + for dtype_config in entry["dtype_configs"]: + s += " {\n" + for k, v in dtype_config.items(): + s += f" '{k}': {v},\n" + s += " },\n" + s += " ],\n" + + # custom output for num_tensor_args_to_observation_type to make it look nice + if "num_tensor_args_to_observation_type" in entry: + s += " 'num_tensor_args_to_observation_type': {\n" + for k, v in entry["num_tensor_args_to_observation_type"].items(): + s += f" {k}: {v},\n" + s += " },\n" + + # output all the other fields + custom_handled_fields = [ + "pattern", + "dtype_configs", + "num_tensor_args_to_observation_type", + ] + for field_name in entry: + if field_name in custom_handled_fields: + continue + s += f" '{field_name}': {entry[field_name]},\n" + + s += "}" + return s + + +def _get_pattern_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Pattern: + """ + Return the pattern specified in the given config in the reversed nested tuple format + used internally in the quantization pattern matching code. + + If the pattern is not a tuple, or the pattern is already specified in the reversed + nested tuple format, return the pattern as is. Otherwise: + + For 2-tuples (a, b), return (b, a). + For 3-tuples (a, b, c), return (c, (b, a)). + + For example: + * Given nn.Linear, return nn.Linear + * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear) + * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + + For context, the reason why this is needed is the user-facing BackendConfig + API accepts the flat 2-or-3-tuple format in forward order. While this simple + format handles the vast majority of use cases, it does not handle the more + complex ones, and so the internal pattern matching code for quantization uses + the following, more general reversed nested tuple format instead: + + operator = module_type | functional | torch op | native op | MatchAllNode + Pattern = (operator, Pattern, Pattern, ...) | operator + + In the future, we expect to replace the above complex format with the one used + by the subgraph rewriter in torch.fx, so we don't have to maintain our own + complex pattern matching code. Then we won't need this helper function anymore. + """ + if config._pattern_complex_format is not None: + return config._pattern_complex_format + if config.pattern is None: + raise ValueError( + "Either 'pattern' or 'pattern_complex_format' must be specified" + ) + if not isinstance(config.pattern, tuple): + return config.pattern + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + (a, b) = config.pattern + return (b, a) + elif len(config.pattern) == 3: + (a, b, c) = config.pattern + return (c, (b, a)) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) + + +def _get_fuser_method_in_reversed_nested_tuple_format( + config: BackendPatternConfig, +) -> Callable: + """ + Return the fuser method specified in the given config in the reversed nested + tuple format used internally in the quantization pattern matching code. + + If pattern is specified in the reversed nested tuple format, we assume the + fuser method is also specified in this format and simply return it as is. + Otherwise, we convert the fuser method as follows: + + * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv) + * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv), + where bn_conv is a 2-tuple (bn, conv) + + The first argument of a fuser method is always `is_qat` and is not affected + in the conversion. We currently only support functions with 3 or 4 arguments. + """ + if config.fuser_method is None: + raise AssertionError("config.fuser_method must be provided") + if config._pattern_complex_format is not None: + return config.fuser_method + if not isinstance(config.pattern, tuple): + raise ValueError("Expected pattern to be a tuple, got: ", config.pattern) + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + return _reverse2(config.fuser_method) + elif len(config.pattern) == 3: + return _reverse3(config.fuser_method) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py new file mode 100644 index 0000000000000000000000000000000000000000..c64b56c981b391140f63038ac507b0708ee876f4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py @@ -0,0 +1,126 @@ +import torch + +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + + +__all__ = [ + "get_x86_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# X86 aligns with FBGEMM for now + +x86_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +x86_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +x86_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +x86_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +x86_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_x86_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native x86 backend. + """ + conv_dtype_configs = [x86_weighted_op_int8_dtype_config] + linear_dtype_configs = [ + x86_weighted_op_int8_dtype_config, + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + default_op_dtype_configs = [x86_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + x86_weight_only_quint8_dtype_config, + x86_weight_only_quint4x2_dtype_config, + ] + return ( + BackendConfig("x86") + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs) + ) + .set_backend_pattern_configs( + _get_tensor_info_op_configs(tensor_info_op_dtype_configs) + ) + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs) + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72d624ad7d6a3926c5d34afab3b7066928f9933d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py @@ -0,0 +1,3 @@ +from .convert import convert +from .fuse import fuse +from .prepare import prepare diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py new file mode 100644 index 0000000000000000000000000000000000000000..0754627a19dd1241dda4c53121f994b1b63ff025 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py @@ -0,0 +1,1268 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch._refs import _unsqueeze_multiple +from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax +from torch.library import impl, Library + + +# Note: decomposed means decomposed quantized tensor, using decomposed so that the +# name is not too long +quantized_decomposed_lib = Library("quantized_decomposed", "DEF") + +_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.uint16, torch.int16, torch.int32] +_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn] + +_DTYPE_TO_QVALUE_BOUNDS = { + k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES +} +_DTYPE_TO_QVALUE_BOUNDS.update( + {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES} +) + + +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + + if quant_min < quant_min_lower_bound: + raise AssertionError( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + if quant_max > quant_max_upper_bound: + raise AssertionError( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta") +def quantize_per_tensor_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def quantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + return torch.empty_like(input, dtype=dtype) + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd" +) +def quantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return quantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + ) + + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta") +def quantize_per_tensor_tensor2_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return quantize_per_tensor_tensor_meta( + input, + scale, + zero_point, # type: ignore[arg-type] + quant_min, # type: ignore[arg-type] + quant_max, # type: ignore[arg-type] + dtype, + ) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with + quantization parameters in the argument of this function (scale/zero_point) + + scale (float): quantization parameter for affine quantization + + zero_point (int): quantization parameter for affine quantization + + quant_min (int): minimum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): dtype for input Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + # TODO: investigate why + # (input - zero_point).to(torch.float32) * scale + # failed the test + return (input.to(out_dtype) - zero_point) * scale + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta") +def dequantize_per_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min, + quant_max, + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.float32 + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + return torch.empty_like(input, dtype=out_dtype) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_tensor.tensor2", + "CompositeExplicitAutograd", +) +def dequantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + return dequantize_per_tensor( + input, + scale.item(), + zero_point.item(), # type: ignore[arg-type] + quant_min.item(), # type: ignore[arg-type] + quant_max.item(), # type: ignore[arg-type] + dtype, + out_dtype=out_dtype, + ) + + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta") +def dequantize_per_tensor_tensor2_meta( + input, + scale, + zero_point, + quant_min, + quant_max, + dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + return dequantize_per_tensor_tensor_meta( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype + ) + + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + ) + + +quantized_decomposed_lib.define( + "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_symmetric.tensor", + "CompositeExplicitAutograd", +) +def choose_qparams_symmetric_tensor( + input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + """Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + qscheme=torch.per_tensor_symmetric, + ) + + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta") +def choose_qparams_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if input.dtype not in [ + torch.float32, + torch.float16, + torch.bfloat16, + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if quant_min >= quant_max: + raise AssertionError( + f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}" + ) + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta") +def choose_qparams_symmetric_tensor_meta( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( + 1, dtype=torch.int64, device=input.device + ) + + +# Helper function used to implement per-channel quantization against any axis +def _permute_to_axis_zero(x, axis): + new_axis_list = list(range(x.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(tuple(new_axis_list)) + return y, new_axis_list + + +quantized_decomposed_lib.define( + "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd") +def quantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + zero_point (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + zero_points = zero_points.view(new_shape) + + res = torch.clamp( + torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max + ) + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + + +@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta") +def quantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype in [torch.float16, torch.bfloat16]: + input = input.to(torch.float32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") +def dequantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + """Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with + quantization parameter in the argument of this function (scales/zero_points/axis) + + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + + zero_points (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + + quant_min (int): minimum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): requested dtype for output Tensor (not used in computation, + reserved for pattern matching) + + out_dtype (torch.dtype?): optional dtype for output Tensor + + Returns: + dequantized float32 Tensor + """ + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + + new_shape = [1] * input.dim() + new_shape[0] = scales.shape[0] + scales = scales.view(new_shape) + if zero_points is not None: + res = (input - zero_points.view(new_shape)) * scales + else: + res = input * scales + + res = res.to(out_dtype) + + out = res.permute(tuple(permute_axis_list)) + return out + + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta") +def dequantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) + if out_dtype is None: + out_dtype = torch.float32 + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=out_dtype) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + + scales = input.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + if dtype == torch.int8: + n_bits = 8 + quant_max = 2 ** (n_bits - 1) - 1 + else: + raise Exception( # noqa: TRY002 + f"unsupported dtype in choose_qparams_per_token: {dtype}" + ) + + scales = scales.clamp(min=1e-5).div(quant_max) + zero_points = torch.zeros_like(scales) + return scales, zero_points + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "Meta", +) +def choose_qparams_per_token_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +quantized_decomposed_lib.define( + "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "_choose_qparams_per_token_asymmetric_impl", + "CompositeImplicitAutograd", +) +def _choose_qparams_per_token_asymmetric_impl( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val = torch.amin(input, dim=-1, keepdim=True) + max_val = torch.amax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? + + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(torch.float64), zero_point.to(torch.int64) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token_asymmetric( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + return _choose_qparams_per_token_asymmetric_impl(input, dtype) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "Meta", +) +def choose_qparams_per_token_asymmetric_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + size = list(input.shape[:-1]) + [1] + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +def _per_token_quant_qparam_dim_check(input, scales, zero_points): + num_tokens = math.prod(list(input.size())[:-1]) + if num_tokens != scales.numel(): + raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}") + if num_tokens != zero_points.numel(): + raise AssertionError( + f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + ) + + +quantized_decomposed_lib.define( + "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd") +def quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + """Per token quantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + _per_token_quant_qparam_dim_check(input, scales, zero_points) + input = ( + input.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp(quant_min, quant_max) + .to(dtype) + ) + return input + + +@impl(quantized_decomposed_lib, "quantize_per_token", "Meta") +def quantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd") +def dequantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + """Per token dequantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): quantized Tensor (uint8, int8 etc.) + scales (float64 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int64 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + input = input - zero_points + input = input * scales + # Since scales are of float64 type, we need to cast it to output dtype requested + return input.to(output_dtype) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta") +def dequantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + # TODO: support fp16 + return torch.empty_like(input, dtype=output_dtype) + + +quantized_decomposed_lib.define( + "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size) -> Tensor" +) + + +# TODO: dtype is ignored for now +@impl( + quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd" +) +def quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") + + # TODO: check for dtype, currently we can't express torch.int4 so it's omitted + to_quant = input.reshape(-1, group_size) + if torch.isnan(to_quant).sum() != 0: + raise AssertionError("to_quant must not contain NaNs") + + scales = scales.reshape(-1, 1) + zero_points = zero_points.reshape(-1, 1) + + input_int8 = ( + to_quant.mul(1.0 / scales) + .add(zero_points) + .round() + .clamp_(quant_min, quant_max) + .to(dtype) + .reshape_as(input) + ) + + return input_int8 + + +@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta") +def quantize_per_channel_group_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + """Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_channel_group", + "CompositeExplicitAutograd", +) +def dequantize_per_channel_group( + w_int8: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor | None, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size: int = 128, + output_dtype: torch.dtype = torch.float32, +): + """Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): quantized Tensor (uint8/int8 etc.) + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + + if group_size <= 1: + raise AssertionError("group_size must be > 1") + # needed for GPTQ single column dequantize + if group_size > w_int8.shape[-1] and scales.shape[-1] == 1: + group_size = w_int8.shape[-1] + if w_int8.shape[-1] % group_size != 0: + raise AssertionError("w_int8.shape[-1] must be divisible by group_size") + if w_int8.dim() != 2: + raise AssertionError("w_int8 must be 2-dimensional") + + w_int8_grouped = w_int8.reshape(-1, group_size) + scales = scales.reshape(-1, 1) + if zero_points is not None: + zp = zero_points.reshape(-1, 1) + else: + zp = torch.zeros([], dtype=torch.int32, device=scales.device) + w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype) + return w_dq + + +quantized_decomposed_lib.define( + "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max) -> Tensor" +) + + +class FakeQuantPerChannel(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): + if scales.dtype != torch.float32: + scales = scales.to(torch.float32) + if zero_points.dtype != torch.int32: + zero_points = zero_points.to(torch.int32) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") + broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim)) + unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) + unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims) + temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points + out = ( + torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points + ) * unsqueeze_scales + mask = torch.logical_and((temp >= quant_min), (temp <= quant_max)) + + ctx.save_for_backward(mask) + return out + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd") +def fake_quant_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return FakeQuantPerChannel.apply( + input, scales, zero_points, axis, quant_min, quant_max + ) + + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta") +def fake_quant_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return torch.empty_like(input) + + +quantized_decomposed_lib.define( + "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "convert_element_type.no_fuse", + "CompositeExplicitAutograd", +) +def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.prims.convert_element_type.default(input, dtype) + + +@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta") +def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.empty_like(input, dtype=dtype) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..dda37214210e34bb7676b9877d2e44876366a07f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py @@ -0,0 +1,1020 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections import namedtuple +from typing import Any + +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.observer import ( + _with_args, + ObserverBase, + PerChannelMinMaxObserver, +) +from torch.ao.quantization.utils import _parent_name, check_min_max_valid +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .utils import ( + get_new_attr_name_with_prefix, + maybe_get_next_module, + node_arg_is_weight, +) + + +CUSTOM_MODULE_SUPP_LIST: list[Any] = [] + + +def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor: + """Reshapes the scale so that we can multiply it to the input by the given axis.""" + new_shape = [1] * input.ndim + new_shape[axis] = input.size(axis) + return scale.view(new_shape) + + +qsheme_mapping_per_tensor_to_per_channel = { + torch.per_tensor_affine: torch.per_channel_affine, + torch.per_tensor_symmetric: torch.per_channel_symmetric, +} + + +class _InputEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of input columns, and + computing the quantization parameters for the overall min/max input values. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + The running minimum/maximum :math:`x_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`, + with the difference that the running min/max values are stored per column. + This observer is intended to be used along with a WeightEqualizationObserver + to calculate the equalization scale. + """ + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + raise TypeError("Input qscheme must be per-tensor") + + self.dtype = dtype + self.qscheme = qscheme + + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.input_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + self.equalization_shape: list[int] = [] + + def forward(self, x_orig): + if x_orig.ndim < 2 or x_orig.ndim > 5: + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers) + self.equalization_shape = [1] * x_orig.ndim + self.equalization_shape[1] = x_orig.size(1) + + return self.input_obs(x_orig) + + def get_input_minmax(self): + return (self.input_obs.min_val, self.input_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + # Reshape the equalization scale along axis=1 so that it can be + # multiplied with the input along axis=1 + if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): + return + self.equalization_scale = torch.reshape( + equalization_scale, self.equalization_shape + ) + + def calculate_scaled_minmax(self): + r"""Returns the scaled min/max inputs""" + if ( + self.equalization_scale.nelement() == 1 + and self.equalization_scale == torch.tensor(1) + ): + warnings.warn( + "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " + + "Will not scale the next quantization observer.", + stacklevel=2, + ) + return None, None + + # Calculate qparams for the scaled min/max inputs + # Scale the input by the equalization scale located at the same column + # index + (min_inputs, max_inputs) = self.get_input_minmax() + equalization_scale_reshaped = reshape_scale( + self.equalization_scale, 0, min_inputs + ) + min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped)) + max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped)) + + return min_input_scaled, max_input_scaled + + with_args = classmethod(_with_args) + + +class _WeightEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of weight columns and + rows, and computing the quantization parameters for the weight rows. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used + to record the running minimum and maximum of columns of incoming weight + tensors. This observer is intended to be used along with an + InputEqualizationObserver to calculate the equalization scale. + + The running minimum/maximum :math:`w_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`. + """ + + def __init__( + self, + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=None, + quant_max=None, + factory_kwargs=None, + ) -> None: + super().__init__() + + self.dtype = dtype + self.qscheme = qscheme + self.ch_axis = 1 + + per_channel_qscheme = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.weight_col_obs = PerChannelMinMaxObserver( + ch_axis=1, + dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + ) + + self.equalization_scale = torch.tensor(1) + + def forward(self, w_orig): + if w_orig.ndim < 2 or w_orig.ndim > 5: + raise ValueError( + "InputEqualizationObserver only supports Linear and Conv layers" + ) + + return self.weight_col_obs(w_orig) + + def get_weight_col_minmax(self): + return (self.weight_col_obs.min_val, self.weight_col_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + self.equalization_scale = equalization_scale + + with_args = classmethod(_with_args) + + +def calculate_equalization_scale( + input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver +) -> torch.Tensor: + r"""Calculates the equalization scale and sets the equalization_scale value + in the observers. + + Args: + input_obs: Observer that tracks the ranges for the input columns + weight_obs: Observer that tracks the ranges for the weight columns + """ + + (min_inputs, max_inputs) = input_obs.get_input_minmax() + (min_weights, max_weights) = weight_obs.get_weight_col_minmax() + + if not ( + check_min_max_valid(min_inputs, max_inputs) + and check_min_max_valid(min_weights, max_weights) + ): + warnings.warn( + "Must run observer before calling calculate_equalization_scale. " + + "Returning default equalization scale torch.tensor(1).", + stacklevel=2, + ) + return torch.tensor(1) + + if min_inputs.shape != min_weights.shape: + raise ValueError( + "Input and Weight must have the same column dimension. " + + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." + ) + + equalization_scale = torch.sqrt( + (max_weights - min_weights) / (max_inputs - min_inputs) + ) + # Replace all 'inf', 'nan', 0's with 1s to prevent errors + equalization_scale[equalization_scale == 0.0] = 1 + equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1) + return equalization_scale + + +class EqualizationQConfig( + # pyrefly: ignore [invalid-inheritance] + namedtuple("EqualizationQConfig", ["input_activation", "weight"]) +): + """ + Describes how to quantize a layer or a part of the network specifically for + input-weight equalization by providing settings (observer classes) for + inputs, outputs, and weights. + + Note that EqualizationQConfig needs to contain observer **classes** (like + MinMaxObserver) or a callable that returns instances on invocation, not the + concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of + the layers. + + Observer classes have usually reasonable default arguments, but they can be + overwritten with `with_args` method (that behaves like functools.partial): + + my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8), + weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): + if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "EqualizationQConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + self = super().__new__(cls, input_activation, weight) + return self + + +input_equalization_observer = _InputEqualizationObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_tensor_symmetric +) +weight_equalization_observer = _WeightEqualizationObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +default_equalization_qconfig = EqualizationQConfig( + input_activation=input_equalization_observer, weight=weight_equalization_observer +) + + +def fused_module_supports_equalization(module) -> bool: + """Checks if the fused node supports equalization.""" + return type(module) in [ + nni.LinearReLU, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + ] + + +def nn_module_supports_equalization(module) -> bool: + """Checks if the torch.nn node supports equalization.""" + return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d] + + +def custom_module_supports_equalization(module) -> bool: + """Checks if the custom node supports equalization.""" + return type(module) in CUSTOM_MODULE_SUPP_LIST + + +def node_supports_equalization(node: Node, modules) -> bool: + """Checks if the current node supports equalization + Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers + """ + if node.op == "call_module": + return ( + nn_module_supports_equalization(modules[str(node.target)]) + or fused_module_supports_equalization(modules[str(node.target)]) + or custom_module_supports_equalization(modules[str(node.target)]) + ) + elif node.op == "call_function": + return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d] + return False + + +def is_equalization_observer(observer: nn.Module) -> bool: + return isinstance( + observer, (_InputEqualizationObserver, _WeightEqualizationObserver) + ) + + +############################################################################### +# Functions for equalization during convert # +############################################################################### + + +def get_op_node_and_weight_eq_obs( + input_eq_obs_node: Node, model: GraphModule, modules: dict[str, nn.Module] +) -> tuple[Node | None, _WeightEqualizationObserver | None]: + """Gets the following weight equalization observer. There should always + exist a weight equalization observer after an input equalization observer. + + Returns the operation node that follows the input equalization observer node + and the weight equalization observer + """ + + # Find the op node that comes directly after the input equalization observer + op_node = None + for user in input_eq_obs_node.users: + if node_supports_equalization(user, modules): + op_node = user + break + + if op_node is None: + raise AssertionError( + "Expected an operation node after the input equalization observer" + ) + if op_node.op == "call_module": + # If the op_node is a nn.Linear layer, then it must have a + # WeightEqualizationObserver configuration + maybe_equalization_node_name_to_config = _get_observed_graph_module_attr( + model, "equalization_node_name_to_qconfig" + ) + if maybe_equalization_node_name_to_config is None: + raise AssertionError( + "Expected 'equalization_node_name_to_qconfig' attribute in observed graph module" + ) + equalization_node_name_to_qconfig: dict[str, Any] = ( + maybe_equalization_node_name_to_config # type: ignore[assignment] + ) + if equalization_node_name_to_qconfig.get(op_node.name, None) is None: + raise AssertionError( + f"No equalization qconfig found for op node {op_node.name}" + ) + weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr] + op_node.name, None + ).weight() + + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + return op_node, weight_eq_obs + + elif op_node.op == "call_function": + weight_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_node is not None: + weight_eq_obs = modules[str(weight_node.target)] + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + return op_node, weight_eq_obs + + return None, None + + +def maybe_get_weight_eq_obs_node( + op_node: Node, modules: dict[str, nn.Module] +) -> Node | None: + """Gets the weight equalization observer node if it exists.""" + if op_node.op != "call_function": + raise AssertionError( + "maybe_get_weight_eq_obs_node expects a call_function op_node" + ) + for node_arg in op_node.args: + if node_arg_is_weight(op_node, node_arg): + if ( + isinstance(node_arg, Node) + and node_arg.op == "call_module" + and isinstance( + modules[str(node_arg.target)], _WeightEqualizationObserver + ) + ): + return node_arg + return None + + +def maybe_get_next_input_eq_obs( + node: Node, modules: dict[str, nn.Module] +) -> _InputEqualizationObserver | None: + """Gets the following input equalization observer if it exists. + + For example, in the case of connecting linear layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + If the node being passed in is the linear1 node, then we want to return eq_obs2, + the following equalization observer for linear2. + + However, if there are no connecting layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add + Then we want to return None. + + In the case of an unfused linear-relu layer with a connecting linear layer: + linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + Since it is unfused, we want to skip over the relu layer and return eq_obs2, + the following equalization observer for linear2. + """ + + if not node_supports_equalization(node, modules): + raise AssertionError("Node does not support equalization") + + # Locate the following nn.ReLU or F.relu node if it exists + maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) + if maybe_relu_node is None: + maybe_relu_node = maybe_get_next_module( + node, modules, target_functional_type=F.relu + ) + + # Locate the following output observer if it exists. + # We will skip the relu node if it exists. + maybe_obs_node = ( + maybe_get_next_module(node, modules, ObserverBase) + if maybe_relu_node is None + else maybe_get_next_module(maybe_relu_node, modules, ObserverBase) + ) + if maybe_obs_node is None: + return None + + maybe_eq_obs_node = maybe_get_next_module( + maybe_obs_node, modules, _InputEqualizationObserver + ) + if maybe_eq_obs_node is None: + return None + + maybe_eq_obs = modules[str(maybe_eq_obs_node)] + if not isinstance(maybe_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the following equalization observer to be an _InputEqualizationObserver" + ) + return maybe_eq_obs + + +def maybe_get_next_equalization_scale( + node: Node, modules: dict[str, nn.Module] +) -> torch.Tensor | None: + """If the next next node is an InputEqualizationObserver then we want to + return its equalization scale, else we return 1 + + This is used in the case where there are two connecting linear layers: + linear1 -> LinearOutObs -> InputEqObs -> linear2 + In this case, the node given is linear1 and we want to locate the InputEqObs. + """ + next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) + # pyrefly: ignore [invalid-argument] + if next_inp_eq_obs: + if ( + next_inp_eq_obs.equalization_scale.nelement() == 1 + and next_inp_eq_obs.equalization_scale == torch.tensor(1) + ): + return None + return next_inp_eq_obs.equalization_scale + return None + + +def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None: + """Scales the following input quantization observer's min/max values by + updating the values with the scaled min/max values calculated by the input + equalization observer + """ + input_eq_obs = modules[str(node.target)] + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the module at node.target to be an _InputEqualizationObserver" + ) + + input_quant_obs_node = node.args[0] + if not isinstance(input_quant_obs_node, Node): + raise AssertionError( + "Expected the input quantization observer node to be a Node" + ) + + input_quant_obs = modules[str(input_quant_obs_node.target)] + if not isinstance(input_quant_obs, ObserverBase): + return + + min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax() + if min_input_scaled is None and max_input_scaled is None: + return + input_quant_obs.min_val = min_input_scaled + input_quant_obs.max_val = max_input_scaled + + +def scale_weight_node( + node: Node, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: torch.Tensor | None, +) -> None: + """Scale the weights for input-weight equalization by multiplying the + weight by 1/equalization_scale and next_equalization_scale + + Args: + node: Current node whose weights we want to scale + equalization_scale: Current node's calculated equalization scale + next_equalization_scale: Next node's calculated equalization scale if + the following node needs to be equalized, 1 otherwise + """ + if equalization_scale is None: + return + + if fused_module_supports_equalization(modules[str(node.target)]): + op_module = modules[str(node.target)][0] # type: ignore[index] + else: + op_module = modules[str(node.target)] + if not ( + nn_module_supports_equalization(op_module) + or custom_module_supports_equalization(op_module) + ): + raise AssertionError( + "Expected operation module to support equalization (nn or custom)" + ) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + weight = op_module.weight + if not isinstance(weight, torch.Tensor): + raise AssertionError("Expected op_module.weight to be a torch.Tensor") + + # Scale the weights by the reciprocal of the equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + op_module.weight = nn.Parameter(scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=0 + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + op_module.weight = nn.Parameter(scaled_weight) + + # Multiply the bias element wise by the next equalization scale + bias = op_module.bias + if bias is None: + return + if not isinstance(bias, torch.Tensor): + raise AssertionError("Expected op_module.bias to be a torch.Tensor") + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + op_module.bias = nn.Parameter(scaled_bias) + + +def scale_weight_functional( + op_node: Node, + model: GraphModule, + modules: dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: torch.Tensor | None, +) -> None: + """Scales the weight value for functional layers""" + if equalization_scale is None: + return + + # From the given op_node, the path looks like: + # get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node + # So we want to trace back from the op_node to get the equalization observer + # node, then the quantization observer node, and then finally the weight + # node which contains the weight values. + + # Get the equalization observer node + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + # Get the quantization observer node + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + if not ( + isinstance(weight_quant_obs_node, Node) + and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) + ): + raise AssertionError( + "Expected weight_quant_obs_node to be a Node whose module is an ObserverBase" + ) + + # Get the get_attr(weight) node + weight_node = weight_quant_obs_node.args[0] + if weight_node is None: + return + if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"): + raise AssertionError("Expected weight node to be a 'get_attr' Node") + + weight_parent_name, weight_name = _parent_name(weight_node.target) + weight = getattr(modules[weight_parent_name], weight_name) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + setattr(modules[weight_parent_name], weight_name, scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + next_equalization_scale_reshaped = reshape_scale( + next_equalization_scale, 0, scaled_weight + ) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + setattr(modules[weight_parent_name], weight_name, scaled_weight) + if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight): + raise AssertionError("Model buffer for weight does not match the scaled weight") + + # Multiply the bias element wise by the next equalization scale + bias_node = None + for node in op_node.args: + # Find the node containing the weight values + if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name: + bias_node = node + break + if bias_node is None: + return + + bias_parent_name, bias_name = _parent_name(bias_node.target) + bias = getattr(modules[bias_parent_name], bias_name) + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + setattr(modules[bias_parent_name], bias_name, scaled_bias) + + +def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) -> None: + """Given the operation node, we want find the corresponding quantization + observer and reset its min/max values + """ + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + if not isinstance(weight_quant_obs_node, Node): + raise AssertionError("Expected weight_quant_obs_node to be a Node") + + weight_quant_obs = modules[str(weight_quant_obs_node.target)] + if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase): + raise AssertionError( + "Expected the module at weight_quant_obs_node to be an ObserverBase" + ) + weight_quant_obs.reset_min_max_vals() # type: ignore[operator] + + +def remove_node(model: GraphModule, node: Node, prev_node: Node): + """Removes the given node from the model by replacing all of its users with + the given previous node + """ + # For all of the current node's users, replace the current node with + # the input quantization observer node + orig_users = list(node.users.keys()) + for user_node in orig_users: + user_node.replace_input_with(node, prev_node) + + # Erase the InputEqualizationObserver node + model.graph.erase_node(node) + + +def update_obs_for_equalization( + model: GraphModule, modules: dict[str, nn.Module] +) -> dict[str, _WeightEqualizationObserver]: + """Update all of the observer's equalization scale. For each + InputEqualizationObserver, we will find the location of the next + WeightEqualizationObserver, create it, and calculate the equalization scale + based on the two observers. + + We will then return a dictionary mapping operation node names to + the corresponding WeightEqualizationObservers for that operation. + """ + weight_eq_obs_dict = {} + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + input_eq_obs = modules[node.target] + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected module at node.target to be an _InputEqualizationObserver" + ) + op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) + + if op_node is None or weight_eq_obs is None: + continue + + if op_node.op == "call_module": + # Calibrate the weight equalization observer since it has just + # been created + if fused_module_supports_equalization(modules[str(op_node.target)]): + module = modules[str(op_node.target)][0] # type: ignore[index] + if not nn_module_supports_equalization(module): + raise AssertionError( + "Expected fused module to support equalization" + ) + weight_eq_obs(module.weight) + else: + weight_eq_obs(modules[str(op_node.target)].weight) + + # Calculate and set the equalization scale values + equalization_scale = calculate_equalization_scale( + input_eq_obs, weight_eq_obs + ) + input_eq_obs.set_equalization_scale(equalization_scale) + weight_eq_obs.set_equalization_scale(equalization_scale) + + weight_eq_obs_dict[op_node.name] = weight_eq_obs + + return weight_eq_obs_dict + + +def convert_eq_obs( + model: GraphModule, + modules: dict[str, nn.Module], + weight_eq_obs_dict: dict[str, _WeightEqualizationObserver], +) -> None: + """Converts the equalization operations and updates the other nodes in the + following way: + - Removes the input equalization observers and inserts a mul operator + along with an equalization scale node wherever applicable (we do not + want to insert a mul operator between connecting linear layers). + - Updates the input quantization observers with the scaled input min/max + values. + - Scales the weights by the current and next equalization scales. + - Removes the weight equalization observer node if it exists. + + Before (after prepare): + weight values + | + WeightQuantObs + | + WeightEqObs + | + x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs + + After this function: + scaled weight values + | + equalization scale WeightQuantObs + | | + x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs + + After convert: + equalization scale scaled weight values + | | + x -> mul -> quantize_per_tensor -> quantized::linear + + Note that although the equalization observer appeared after the quantization + observer after prepare_fx, the mul node appears before the quantization node + after convert_fx. This is because placing the equalization observer after + the quantization observer in prepare_fx would allow us to keep the invariant + that the graph before the current node inserts its observers is not + modified. + + Having the equalization observer before the quantization observer would also + cause some inconsistences between the ordering of the quantization and + equalization observers. + For example, a single linear layer would look like: + x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1 + But between two connected linear layers, it would look like: + linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2 + """ + for node in model.graph.nodes: + if node.op == "call_module" and isinstance( + modules[node.target], _InputEqualizationObserver + ): + inp_quant_obs_node = node.args[0] + prev_node = inp_quant_obs_node.args[0] + + # If the previous node is a layer that needs to be equalized, then + # we will remove the current node because we do not need to add any + # equalization nodes between two layers that need to be equalized + + # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2 + # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2 + if ( + node_supports_equalization(prev_node, modules) + or "relu" in prev_node.name + ): + remove_node(model, node, inp_quant_obs_node) + continue + + # Update the following input quantization observer's min/max values + scale_input_observer(node, modules) + + # Remove the InputEqualization node and add a mul operator before + # the quantization observer node that appears before the equalization node + # Before: x -> input_quant_obs -> input_eq_obs -> linear + # After: x -> mul -> input_quant_obs -> linear + + # Create a node containing the equalization scale + with model.graph.inserting_before(inp_quant_obs_node): + get_new_eq_scale_name = get_new_attr_name_with_prefix( + prev_node.name + "_equalization_scale" + ) + name = get_new_eq_scale_name(modules) + setattr(model, name, modules[node.target].equalization_scale) + eq_scale_node = model.graph.create_node("get_attr", name) + + # Create a node multiplying the input with the equalization scale + with model.graph.inserting_after(eq_scale_node): + inputs = (prev_node, eq_scale_node) + mul_node = model.graph.create_node("call_function", torch.mul, inputs) + + # Set the mul nod to be the input_quant_obs_node's input instead of + # the previous node + inp_quant_obs_node.replace_input_with(prev_node, mul_node) + remove_node(model, node, inp_quant_obs_node) + + elif weight_eq_obs_dict.get(node.name, None) is not None: + weight_eq_obs = weight_eq_obs_dict.get(node.name) + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + equalization_scale = weight_eq_obs.equalization_scale + + if ( + equalization_scale.nelement() == 1 + and equalization_scale == torch.tensor(1) + ): + equalization_scale = None # type: ignore[assignment] + maybe_next_equalization_scale = maybe_get_next_equalization_scale( + node, modules + ) + + # Scale the weight nodes + if node.op == "call_module": + scale_weight_node( + node, + modules, + # pyrefly: ignore [bad-argument-type] + equalization_scale, + maybe_next_equalization_scale, + ) + elif node.op == "call_function": + scale_weight_functional( + node, + model, + modules, + # pyrefly: ignore [bad-argument-type] + equalization_scale, + maybe_next_equalization_scale, + ) + + weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) + if weight_eq_obs_node is None: + return + if not isinstance( + modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver + ): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) + + # Clear the quantization observer's min/max values so that they + # can get updated later based on the new scale values + clear_weight_quant_obs_node(node, modules) + + # Erase the weight equalization observer node + prev_node = weight_eq_obs_node.args[0] + remove_node(model, weight_eq_obs_node, prev_node) # type: ignore[arg-type] + else: + raise ValueError( + "Expected operation node to be 'call_module' or 'call_function" + + f"Instead got node {node.name} as '{node.op}'." + ) + + +def _convert_equalization_ref(model: GraphModule): + """Reference function which applies changes needed for equalization, but + does not quantize the nodes + """ + modules = dict(model.named_modules(remove_duplicate=False)) + + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + return GraphModule(model, model.graph) + + +############################################################################### +# Functions for running the equalized model on the Numeric Suite # +############################################################################### + + +def get_layer_sqnr_dict( + model_a: nn.Module, model_b: nn.Module, x: torch.Tensor +) -> dict[str, float]: + """Runs the Numeric Suite on model_a and model_b and returns a dictionary + containing the SQNR between layers in model_a and model_b. + + Note: In order to support equalized models, this function has a hacky fix in + which we do not match any torch.mul operators. This is because equalized + models contain extra mul operators to scale the input by the equalization + scale, but this edge case has not been resolved yet within the numeric suite code. + + Args: + model_a: A float model + model_b: A quantized model + x: Inputs to use during calibration + """ + import torch.ao.ns._numeric_suite_fx as ns + from torch.ao.ns.fx.mappings import get_unmatchable_types_map + + unmatchable_types_map = get_unmatchable_types_map() + unmatchable_types_map["funs_unmatchable"].add(torch.mul) + + model_a_ns, model_b_ns = ns.add_loggers( + "fp32", + model_a, + "int8", + model_b, + ns.OutputLogger, + unmatchable_types_map=unmatchable_types_map, + ) + + model_a_ns(x) + model_b_ns(x) + + activation_comparison_dict = ns.extract_logger_info( + model_a_ns, model_b_ns, ns.OutputLogger, "int8" + ) + ns.extend_logger_results_with_comparison( + activation_comparison_dict, + "fp32", + "int8", + torch.ao.ns.fx.utils.compute_sqnr, + "sqnr", + ) + + # Construct a dictionary mapping layer names to the SQNR values + layer_sqnr_dict = {} + for key in activation_comparison_dict: + layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"] + sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0] + layer_sqnr_dict[layer] = sqnr + + return layer_sqnr_dict + + +def get_equalization_qconfig_dict( + layer_sqnr_dict: dict[str, float], num_layers_to_equalize: int +) -> Any: + """Given the layer to SQNR dictionary, find the layers with the highest + quantization errors, and return an equalization_qconfig_dict + specifying to only equalize those top layers. + + Args: + layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found + when comparing an equalized model against a float model) + num_layers_to_equalize: Number of layers with the highest quantization + errors to equalize + """ + + # Sort the layer_sqnr_dictionary values and get the layers with the lowest + # SQNR values (aka highest quantization errors) + layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1)) + layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] + + # Constructs an equalization_qconfig_dict that specifies to only equalize + # the layers with the highest quantization errors + module_to_qconfig_list = [ + (item[0], default_equalization_qconfig) for item in layers_to_equalize + ] + equalization_qconfig_dict = {"module_name": module_to_qconfig_list} + return equalization_qconfig_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..ad20bcc96251d8fb439e5201a2038e28e5ec675b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -0,0 +1,1413 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantization_mappings import get_quantized_operator +from torch.ao.quantization.utils import _parent_name +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .utils import ( + collect_producer_nodes, + create_node_from_old_node_preserve_meta, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_qconv_prepack_op, + graph_module_from_producer_nodes, +) + + +QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = { + torch._ops.ops.quantized.hardswish: ["inplace"], + torch._ops.ops.quantized.elu: ["inplace"], + torch._ops.ops.quantized.dropout: ["inplace"], + torch._ops.ops.quantized.instance_norm: [ + "running_mean", + "running_var", + "use_input_stats", + "momentum", + ], +} + + +def _is_node_in_list(node, modules, func_list, method_list, module_type_list): + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def is_fixed_qparams_node(node, modules): + func_list = [ + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.sigmoid, + torch.tanh, + ] + method_list = [ + "hardsigmoid", + "hardsigmoid_", + "sigmoid", + "sigmoid_", + "tanh", + "tanh_", + ] + module_type_list = [ + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, + torch.nn.Softmax, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_default_node(node, modules): + func_list = [ + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + method_list: list[Any] = [] + module_type_list = [ + nnqr.ConvTranspose1d, + nnqr.ConvTranspose2d, + nnqr.ConvTranspose3d, + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.ao.nn.intrinsic.BNReLU2d, + torch.ao.nn.intrinsic.BNReLU3d, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_copy_node(node, modules): + func_list = [ + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + operator.floordiv, + # F.channel_shuffle and torch.channel_shuffle are essentially the same thing + # so we only need to put one of them here + torch.channel_shuffle, + ] + method_list = [ + "clamp", + "mean", + "relu", + "relu_", + ] + module_type_list = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.ChannelShuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_general_tensor_shape_node(node, modules): + func_list = [ + torch.narrow, + torch.transpose, + torch.repeat_interleave, + torch.squeeze, + torch.stack, + torch.unsqueeze, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + ] + method_list = [ + "contiguous", + "detach", + "detach_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "size", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + module_type_list = [ + torch.nn.Identity, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_other_node(node, modules): + func_list = [ + torch.cat, + ] + method_list: list[Any] = [] + module_type_list: list[Any] = [] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + + +def is_special_pattern_node(node, modules): + res_function, res_method, res_module = False, False, False + for checker in [ + is_fixed_qparams_node, + is_default_node, + is_copy_node, + is_general_tensor_shape_node, + is_other_node, + ]: + is_call_function, is_call_method, is_call_module = checker(node, modules) + res_function = res_function or is_call_function + res_method = res_method or is_call_method + res_module = res_module or is_call_module + return res_function, res_method, res_module + + +def is_dequantize_node(node): + return ( + isinstance(node, Node) + and node.op == "call_method" + and node.target == "dequantize" + ) + + +def is_getattr_tensor_metadata_node(node): + return ( + node.op == "call_function" + and node.target is getattr + and node.args[1] == "shape" + ) + + +def is_get_tensor_info_node(node): + return node.op == "call_method" and node.target in ["shape", "size"] + + +def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: dict[str, QConfigAny]): + """ + Return True if the op is configured with a None qconfig, False otherwise. + Note: maybe need to generalize this to also check for the dtype, and we + only lower when dtype matches, but right now fbgemm/qnnpack only support + a single dtype, so it is OK for now. + """ + return op.name in qconfig_map and qconfig_map[op.name] is None + + +# Mapping from reference module class to the replacement static quantized module class for lowering +STATIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[WeightedQuantizedModule]] = { + nnqr.Linear: nnq.Linear, + nnqr.Conv1d: nnq.Conv1d, + nnqr.Conv2d: nnq.Conv2d, + nnqr.Conv3d: nnq.Conv3d, +} + +# Mapping from reference module class to the replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Linear: nnqd.Linear, + nnqr.GRUCell: nnqd.GRUCell, + nnqr.LSTMCell: nnqd.LSTMCell, + nnqr.RNNCell: nnqd.RNNCell, + nnqr.LSTM: nnqd.LSTM, + nnqr.GRU: nnqd.GRU, +} + +# Mapping from reference module class to the replacement weight only quantized module class for lowering +# TODO: correct the namespace for these modules +WEIGHT_ONLY_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = { + nnqr.Embedding: nnq.Embedding, + nnqr.EmbeddingBag: nnq.EmbeddingBag, +} + +# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge +# _lower_static_weighted_ref_module and special_pattern_replacement +SPECIAL_PATTERN_LOWER_MODULE_MAP = { + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nnqr.ConvTranspose1d: nnq.ConvTranspose1d, + nnqr.ConvTranspose2d: nnq.ConvTranspose2d, + nnqr.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.LeakyReLU: nnq.LeakyReLU, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.Dropout: nnq.Dropout, + nn.Softmax: nnq.Softmax, + nn.PReLU: nnq.PReLU, + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU), + # TODO: LinearLeakyReLU is registered as global but it is only fused and + # lowered when ondnn's backend config is used. Maybe need to separate + # registration and lowering functions for different backends in the future. + nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU), + nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh), + nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d), + nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d), + nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d), +} + +# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP: +# The refer node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs. +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]] +] = { + nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d), + nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d), +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_FUSED_MODULE_MAP: dict[ + type[nn.Module], tuple[type[nn.Module], type[nn.Module]] +] = { + nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU), +} + +# Mapping from a functional to lower to a 2-tuple of +# 1) The quantized version of the op +# 2) The quantized version of the op fused with relu, if it exists, else None +STATIC_LOWER_FUNCTIONAL_MAP: dict[Callable, tuple[Callable, Callable | None]] = { + F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu), + F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu), + F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu), + F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu), + F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None), + F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None), + F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None), +} + +WEIGHT_PREPACK_OPS: set[Callable] = { + torch._ops.ops.quantized.linear_prepack, + torch._ops.ops.quantized.linear_prepack_fp16, + torch._ops.ops.quantized.conv1d_prepack, + torch._ops.ops.quantized.conv2d_prepack, + torch._ops.ops.quantized.conv3d_prepack, + torch.ops.quantized.conv_transpose1d_prepack, + torch.ops.quantized.conv_transpose2d_prepack, + torch.ops.quantized.conv_transpose3d_prepack, +} + +# Mapping from a functional to a dictionary, where the key is a 2-tuple of +# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of +# 1) The dynamically quantized version of the op +# 2) The dynamically quantized version of the op fused with relu, if it exists, else None +DYNAMIC_LOWER_FUNCTIONAL_MAP: dict[ + Callable, dict[tuple[torch.dtype, torch.dtype], tuple[Callable, Callable | None]] +] = { + F.linear: { + (torch.quint8, torch.qint8): ( + torch.ops.quantized.linear_dynamic, + torch.ops.quantized.linear_relu_dynamic, + ), + (torch.float16, torch.float16): ( + torch.ops.quantized.linear_dynamic_fp16, + torch.ops.quantized.linear_relu_dynamic_fp16, + ), + }, + # dynamic conv + relu is not available yet + F.conv1d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None), + }, + F.conv2d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None), + }, + F.conv3d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None), + }, +} + +CONV_FUNCTIONAL_OPS: set[Callable] = { + F.conv1d, + F.conv2d, + F.conv3d, +} + +CONV_TRANSPOSE_FUNCTIONAL_OPS: set[Callable] = { + F.conv_transpose1d, + F.conv_transpose2d, + F.conv_transpose3d, +} + +# TODO: add tests for lowering these ops +QBIN_OP_MAPPING: dict[Callable | str, Callable] = { + operator.add: torch.ops.quantized.add, + torch.add: torch.ops.quantized.add, + operator.mul: torch.ops.quantized.mul, + operator.matmul: torch.ops.quantized.matmul, + torch.mul: torch.ops.quantized.mul, + torch.matmul: torch.ops.quantized.matmul, +} +QBIN_RELU_OP_MAPPING: dict[Callable | str, Callable] = { + operator.add: torch.ops.quantized.add_relu, + torch.add: torch.ops.quantized.add_relu, + operator.mul: torch.ops.quantized.mul_relu, + torch.mul: torch.ops.quantized.mul_relu, +} + +ORIGINAL_WEIGHTS_LOOKUP = "original_weights_lookup" + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +def _load_packed_weight( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + +def fold_weight( + quantized_model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """ + Trace back from the weight node util we hit getattr, reconstruct the + graph module with the traced nodes and run the graph module to pack the + weight. then replace the original chain of ops with the packed weight. + """ + packed_weights = {} + # map from folded node name to the prepacked weight name + folded_nodes = {} + original_weights_lookup: dict[str, list] = {} + lookup_counter = 0 + # get packed weights + for node in quantized_model.graph.nodes: + if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS: + nodes_to_fold = collect_producer_nodes(node) + if nodes_to_fold is not None: + for node_to_fold in nodes_to_fold: + folded_nodes[node_to_fold.name] = node + + prepacking_module = graph_module_from_producer_nodes( + quantized_model, nodes_to_fold + ) + packed_weight = prepacking_module() + packed_weights[node.name] = packed_weight + if keep_original_weights: + original_weights = list(prepacking_module.state_dict().values()) + original_weights_lookup[str(lookup_counter)] = sorted( + original_weights, key=lambda x: x.numel(), reverse=True + ) + if len(original_weights_lookup[str(lookup_counter)]) == 1: + # bias is None + original_weights_lookup[str(lookup_counter)].append(None) + lookup_counter += 1 + lookup_counter = 0 + + # remove folded nodes and replace the prepacking node with getattr + folded_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in quantized_model.graph.nodes: + prepack_node = folded_nodes.get(node.name, None) + if prepack_node is node: + packed_weight = packed_weights[node.name] + # add a prepacked attribute to root + op_node = next(iter(prepack_node.users)) + module_path, _ = node_name_to_scope[op_node.name] + get_new_packed_weight_name = get_new_attr_name_with_prefix( + module_path + "_packed_weight_" + ) + packed_weight_name = get_new_packed_weight_name(quantized_model) + setattr(quantized_model, packed_weight_name, packed_weight) + # replace prepack node with a getattr node + env[node.name] = folded_graph.create_node( + "get_attr", packed_weight_name, (), {} + ) + if keep_original_weights: + key_name = ( + packed_weight_name.replace(":", "_") + .replace("/", "_") + .replace("|", "_") + .replace(" ", "") + .lower() + ) + original_weights_lookup[key_name] = original_weights_lookup[ + str(lookup_counter) + ] + del original_weights_lookup[str(lookup_counter)] + lookup_counter += 1 + elif prepack_node is not None: + # remove the fold node + continue + else: + # copy other nodes + env[node.name] = folded_graph.node_copy(node, load_arg) + + quantized_model = GraphModule(quantized_model, folded_graph) + quantized_model._register_state_dict_hook(_save_packed_weight) + quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) + + if keep_original_weights: + setattr( # noqa: B010 + quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup + ) + + return quantized_model + + +def _get_module(node: Node, modules: dict[str, nn.Module]) -> nn.Module | None: + """ + Return the `torch.nn.Module` that corresponds to the specified node's target. + If no such node exists, return None. + """ + if node.op == "call_module" and str(node.target) in modules: + return modules[str(node.target)] + else: + return None + + +def _match_static_pattern( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], + dequantize_node_arg_indices: list[int], +) -> tuple[Node, Node, Node] | tuple[None, None, None]: + """ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 3-tuple of: + 1) q_node: the quantize node, + 2) relu_node: a relu node wrapping the ref_node, and + 3) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 3-tuple of (None, None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + dequantize_node_arg_indices: A list of indices in the reference node args where dequantize + nodes may be present. An empty list means skipping the check for dequantize nodes. + """ + SKIP_LOWERING_VALUE = (None, None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") + + # Handle cases where the node is wrapped in a ReLU + if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or ( + ref_node.op == "call_module" and type(_get_module(ref_node, modules)) is nn.ReLU + ): + relu_node = ref_node + ref_node = relu_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError( + "Expected the reference node after ReLU to be a torch.fx Node" + ) + else: + relu_node = None + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + expected_op = "call_function" + match_key = ref_node.target # type: ignore[assignment] + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Match dequantize node(s). Both of the following conditions must pass: + # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node + # (2) There must be at least one dequantize node + matched_dequantize = False + for i in dequantize_node_arg_indices: + if i >= len(ref_node.args): + raise AssertionError( + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + ) + arg = ref_node.args[i] + if is_dequantize_node(arg): + matched_dequantize = True + elif isinstance(arg, Node): + return SKIP_LOWERING_VALUE + if not matched_dequantize: + return SKIP_LOWERING_VALUE + + return (q_node, relu_node, ref_node) # type: ignore[return-value] + + +def _match_static_pattern_with_two_inputs( + node: Node, + modules: dict[str, nn.Module], + qconfig_map: dict[str, QConfigAny], + matching_modules_or_ops: list[Callable], +) -> tuple[Node, Node] | tuple[None, None]: + """ + (dequantize \ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 2-tuple of: + 1) q_node: the quantize node, + 2) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 2-tuple of (None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + """ + SKIP_LOWERING_VALUE = (None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") + + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass( + matching_modules_or_ops[0], nn.Module + ): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + # This pass only support op of "call_module" + return SKIP_LOWERING_VALUE + + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Check ref_node has 2 input nodes, both are dq node. + if len(ref_node.args) != 2: + return SKIP_LOWERING_VALUE + for i in range(len(ref_node.args)): + arg = ref_node.args[i] + if not is_dequantize_node(arg): + return SKIP_LOWERING_VALUE + + return (q_node, ref_node) + + +def _lower_static_weighted_ref_module( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find dequantize - ref module - quantize patterns + and replace them with the quantized version of the ref module. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list( + STATIC_LOWER_FUSED_MODULE_MAP.keys() + ) + q_node, _relu_node, ref_node = _match_static_pattern( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + dequantize_node_arg_indices=[0], + ) + if q_node is None: + continue + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern" + ) + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] + continue + else: + q_class = STATIC_LOWER_MODULE_MAP[ref_class] + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + if len(ref_node.args) != 1: + raise AssertionError("Expected reference node to have exactly 1 arg") + dq_node = ref_node.args[0] + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_static_weighted_ref_module_with_two_inputs( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and find patterns + dequantize dequantize + \\ // + ref module + \\ + quantize + and replace them with the quantized version of the ref module. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # (dequantize \ + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys()) + (q_node, ref_node) = _match_static_pattern_with_two_inputs( + n, + modules, + qconfig_map, + matching_modules, # type: ignore[arg-type] + ) + if q_node is None: + continue + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern with two inputs" + ) + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ + ref_class + ] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] + continue + else: + continue + output_scale = getattr(model, scale_node.target) # type: ignore[arg-type] + output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type] + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + if len(ref_node.args) != 2: + raise AssertionError("Expected reference node to have exactly 2 args") + for arg in ref_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] + + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + + +def _lower_dynamic_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns + and replace them with the dynamically quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + DYNAMIC_LOWER_MODULE_MAP.keys() + ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())): + continue + ref_node = n + dq_node = ref_node.args[0] + if dq_node.op != "call_method" or dq_node.target != "dequantize": + continue + + input_dynamic_q_node = dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + activation_dtype = input_dynamic_q_node.args[1] + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) is not inner_ref_class: + continue + else: + q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment] + # TODO: maybe define a WeightedDynamicallyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[attr-defined] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0]) + + +def _lower_weight_only_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find ref_module patterns + and replace them with the weight only quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or type(named_modules[str(n.target)]) not in set( + WEIGHT_ONLY_LOWER_MODULE_MAP.keys() + ): + continue + ref_node = n + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class) + # TODO: WeightedQuantizedModule is currently assuming static quant apis + # with output_scale, output_zero_point in from_reference, we may want to + # relax that, or rename this + # TODO: maybe define a WeightedWeightOnlyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[union-attr] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + + +def _lower_static_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their quantized versions. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize) + matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys()) + (q_node, relu_node, func_node) = _match_static_pattern( + n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1] + ) + if q_node is None: + continue + if func_node is None: + raise AssertionError( + "Expected a function node when matching static functional pattern" + ) + (_, output_scale_node, output_zp_node, _) = q_node.args + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if not isinstance(output_zp_node, Node): + raise AssertionError("Expected output_zp_node to be a Node") + if not isinstance(input_dq_node, Node): + raise AssertionError("Expected input_dq_node to be a Node") + if not isinstance(weight_dq_node, Node): + raise AssertionError("Expected weight_dq_node to be a Node") + quantized_weight = weight_dq_node.args[0] + if not isinstance(quantized_weight, Node): + raise AssertionError("Expected quantized_weight to be a Node") + if quantized_weight.op != "call_function" or quantized_weight.target not in ( + torch.quantize_per_tensor, + torch.quantize_per_channel, + ): + continue + + # Step 1: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + if func_node.target is F.linear: + weight_dtype = quantized_weight.args[-1] + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv_transpose1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv_transpose1d: + # Note prepack_args[5] is groups. + for i in [2, 3, 4, 6]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + # swap dilation and groups + # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups} + # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation} + if len(prepack_args) > 6: + prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5] + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type] + # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack) + # They are not needed for compute op (i.e., quantized::linear) + kwargs = func_node.kwargs + # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias + if func_node.target is F.linear and "bias" in kwargs: + kwargs = kwargs.copy() + kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), kwargs + ) + + # Step 2: Replace reference pattern with the corresponding quantized op + (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target] # type: ignore[index] + # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases + if q_relu_func is not None: + func_node.target = q_relu_func if relu_node is not None else q_func + else: + func_node.target = q_func + func_node.args = ( + input_dq_node.args[0], + packed_weight, + output_scale_node, + output_zp_node, + ) + # kwargs for func_node has been moved to kwargs for prepack op + func_node.kwargs = {} + q_node.replace_all_uses_with(func_node) + # Move func_node after output_zp_node in the graph + output_zp_node.append(func_node) + + # Clean up: Remove quantize node, and the relu node if it exists + model.graph.erase_node(q_node) + if relu_node is not None and q_relu_func is not None: + model.graph.erase_node(relu_node) + + +def _lower_dynamic_weighted_ref_functional( + model: GraphModule, qconfig_map: dict[str, QConfigAny] +): + """ + Traverse the graph and replace functional reference patterns with their dynamically + quantized versions. + Examples: + quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic + to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16 + """ + modules = dict(model.named_modules(remove_duplicate=False)) + # we want to search in reserved order so that we can match the larger patterns first + # e.g. we want to match linear - relu before linear. + for n in reversed(model.graph.nodes): + # Step 0: Find nodes that match this pattern + # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op) + # We search for the pattern backwards, starting with the quantize node + # Quantize node args: (func, scale, zp, dtype) + func_node = n + # Handle cases where the functional op is wrapped in a ReLU + if ( + func_node.op == "call_function" + and func_node.target is F.relu + or func_node.op == "call_module" + and type(modules[str(func_node.target)]) is torch.nn.ReLU + ): + relu_node = func_node + func_node = relu_node.args[0] + else: + relu_node = None + if should_skip_lowering(func_node, qconfig_map): + continue + # Linear args: (dequantized inputs, dequantized weights[, bias]) + # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups]) + if ( + func_node.op != "call_function" + or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP + ): + continue + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if ( + input_dq_node.op != "call_method" + or input_dq_node.target != "dequantize" + or weight_dq_node.op != "call_method" + or weight_dq_node.target != "dequantize" + ): + continue + + input_dynamic_q_node = input_dq_node.args[0] + + if ( + input_dynamic_q_node.op != "call_function" + or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic + ): + continue + + reduce_range_node = None + (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + quantized_weight = weight_dq_node.args[0] + weight_dtype = quantized_weight.args[-1] + + # Step 1: Try to select reference pattern with the corresponding quantized op + dynamic_quant_dtype_key = (activation_dtype, weight_dtype) + if ( + dynamic_quant_dtype_key + not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target] + ): + print( + f"Didn't find dtype combination {dynamic_quant_dtype_key} during " + f"dynamic quantized op lowering for {func_node.target}" + ) + continue + (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][ + dynamic_quant_dtype_key + ] + + if q_func is None or q_relu_func is None: + print( + "Didn't find corresponding quantized function or quantized relu function " + f"for {func_node.target}, {dynamic_quant_dtype_key}" + ) + continue + + # Step 2: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + prepack_kwargs = {} + if func_node.target is F.linear: + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + kwargs = func_node.kwargs.copy() + if "bias" in kwargs: + prepack_kwargs["B"] = kwargs["bias"] + del kwargs["bias"] + func_node.kwargs = kwargs + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target is F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(func_node): + packed_weight = model.graph.create_node( + "call_function", prepack_op, tuple(prepack_args), prepack_kwargs + ) + + # Step 3: Replace reference pattern with the corresponding quantized op + func_node.target = q_relu_func if relu_node is not None else q_func + if is_int8: + func_node.args = (pattern_input, packed_weight, reduce_range_node) + else: + func_node.args = (pattern_input, packed_weight) + + if relu_node is not None: + relu_node.replace_all_uses_with(func_node) + + # Step 4: Remove the relu node if it exists + if relu_node is not None: + model.graph.erase_node(relu_node) + + +def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfigAny]): + binary_ops_to_lower: list[Callable] = [ + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.matmul, + ] + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + (q_node, relu_node, bop_node) = _match_static_pattern( + n, + modules, + qconfig_map, + binary_ops_to_lower, + dequantize_node_arg_indices=[0, 1], + ) + if q_node is None: + continue + if bop_node is None: + raise AssertionError( + "Expected a binary op node when matching quantized binary op pattern" + ) + (_, scale_node, zero_point_node, _) = q_node.args + + # Step 1: Remove dequant nodes + num_dq_nodes = 0 + for arg in bop_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") + dn_input = dq_node.args[0] + bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type] + num_dq_nodes += 1 + if num_dq_nodes <= 0: + raise AssertionError( + "Expected at least one dequantize node in binary op args" + ) + + # Step 2: Swap binary op to quantized binary op + if bop_node.target not in QBIN_OP_MAPPING: + raise AssertionError( + f"Unsupported binary op {bop_node.target} for lowering" + ) + binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING + qbin_op = binop_to_qbinop[bop_node.target] + # prepare the args for quantized binary op + # (x, y) + qop_node_args = list(bop_node.args) + # (x, y, scale, zero_point) + # add scale and zero_point arguments for Tensor - Tensor operation + if num_dq_nodes == 2: + qop_node_args.extend([scale_node, zero_point_node]) + # insert a call to quantized binary op and remove the original binary op + with model.graph.inserting_after(q_node): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, + ("call_function", qbin_op, tuple(qop_node_args), {}), + bop_node, + ) + q_node.replace_all_uses_with(qop_node) + + # Step 3: Remove quantize node, binary op node, and relu node if any + model.graph.erase_node(q_node) + if relu_node is not None: + model.graph.erase_node(relu_node) + model.graph.erase_node(bop_node) + + +def special_pattern_replacement(model: GraphModule): + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + q_node = n + is_quantize = q_node.target is torch.quantize_per_tensor + is_to_fp16 = ( + q_node.op == "call_method" + and q_node.target == "to" + and len(q_node.args) == 2 + and q_node.args[1] == torch.float16 + ) + # Only continue when neither quantize nor to_fp16 + if not is_quantize and not is_to_fp16: + continue + ref_node = q_node.args[0] + # get output scale/zero_point/dtype from the quantize node + # ref_node, scale_node, zero_point_node, dtype = q_node.args + # TODO: add safety checks that users for the ref_node and dq_node needs to be one + is_call_function, is_call_method, is_call_module = is_fixed_qparams_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + # warnings.warn( + # "Only reference patterns are currently supported for {dtype} dtype with {op} op" + # "".format(dtype=dtypes, op=ref_node)) + continue + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + continue + + # This check includes all supported ops + is_call_function, is_call_method, is_call_module = is_special_pattern_node( + ref_node, modules + ) + if not (is_call_module or is_call_function or is_call_method): + continue + if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0: + raise AssertionError("Expected ref_node to have args or kwargs") + dq_node_or_nodes = ( + ref_node.args[0] + if len(ref_node.args) > 0 + else next(iter(ref_node.kwargs.values())) + ) + if not isinstance(dq_node_or_nodes, (Node, tuple, list)): + raise AssertionError( + "Expected dq_node_or_nodes to be a Node, tuple, or list" + ) + is_dequantize = False + if isinstance(dq_node_or_nodes, Node): + is_dequantize = ( + dq_node_or_nodes.op == "call_method" + and dq_node_or_nodes.target == "dequantize" + ) + elif isinstance(dq_node_or_nodes, (tuple, list)): + is_dequantize = all( + x.op == "call_method" and x.target == "dequantize" + for x in dq_node_or_nodes + ) + + if not is_dequantize: + continue + + # TODO: enable we have patterns that needs to swap the modules + if is_call_module: + ref_module = modules[ref_node.target] + if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize: + qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module)) + scale_node = q_node.args[1] + zero_point_node = q_node.args[2] + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + + qmodule = qmodule_cls.from_reference( # type:ignore[union-attr] + ref_module, output_scale, output_zero_point + ) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, qmodule) + + # reroute around dq node: + dq_nodes: list[Node] = [] + if isinstance(dq_node_or_nodes, Node): + dq_nodes = [dq_node_or_nodes] + elif isinstance(dq_node_or_nodes, (tuple, list)): + dq_nodes = list(dq_node_or_nodes) + + for dq_node in dq_nodes: + dn_input = dq_node.args[0] + ref_node.replace_input_with(dq_node, dn_input) + + # store q node args + qnode_qparams = list(q_node.args)[1:] + # replace uses of q node with input and remove q node + q_node_input = q_node.args[0] + q_node.replace_all_uses_with(q_node_input) + model.graph.erase_node(q_node) + + is_call_function, is_call_method, is_call_module = is_default_node( + ref_node, modules + ) + if is_call_function: + # pass scale/zer_point arguments from quantize_per_tensor to the default node operator + # insert an op after the zero_point node so that the scale/zero_point + # nodes are is available + qop = get_quantized_operator(ref_node.target) + args = list(ref_node.args) + kwargs = dict(ref_node.kwargs) + if qop in QOP_TO_ARG_NAMES_TO_SKIP: + args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop] + for arg in args_to_skip: + if arg in kwargs: + kwargs.pop(arg) + kwargs["output_scale"] = qnode_qparams[0] + kwargs["output_zero_point"] = qnode_qparams[1] + with model.graph.inserting_after(qnode_qparams[1]): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, ("call_function", qop, tuple(args), kwargs), ref_node + ) + ref_node.replace_all_uses_with(qop_node) + model.graph.erase_node(ref_node) + else: + # remove scale/zero_point node for quantize node + for n in qnode_qparams: + if isinstance(n, Node): + model.graph.erase_node(n) + + return model + + +def _lower_getattr_tensor_metadta_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if is_getattr_tensor_metadata_node(n): + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_get_tensor_info_op(model: GraphModule): + """Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if not is_get_tensor_info_node(n): + continue + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + + +def _lower_to_native_backend( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same + operator signature so they can be lowered with the same function + """ + _lower_static_weighted_ref_module(model, qconfig_map) + _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map) + _lower_dynamic_weighted_ref_module(model) + _lower_weight_only_weighted_ref_module(model) + _lower_static_weighted_ref_functional(model, qconfig_map) + _lower_dynamic_weighted_ref_functional(model, qconfig_map) + _lower_quantized_binary_op(model, qconfig_map) + _lower_getattr_tensor_metadta_op(model) + _lower_get_tensor_info_op(model) + special_pattern_replacement(model) + model.graph.eliminate_dead_code() + model = fold_weight(model, node_name_to_scope, keep_original_weights) + model.graph.eliminate_dead_code() + model.recompile() + model.graph.lint() + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a19a40cab908baa78fffeb89f46eedc71976736 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py @@ -0,0 +1,1323 @@ +# mypy: ignore-errors + +import copy +import operator +import warnings +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fused_module_classes, + get_pattern_to_dtype_configs, + get_qat_module_classes, + get_root_module_to_quantized_reference_module, +) +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quant_type import QuantType +from torch.ao.quantization.quantize import _remove_qconfig +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _parent_name, + activation_is_statically_quantized, + get_qparam_dict, + get_swapped_custom_module_class, + is_per_channel, + to_underlying_dtype, + weight_is_quantized, +) +from torch.fx import GraphModule +from torch.fx.graph import Argument, Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from ._equalize import convert_eq_obs, update_obs_for_equalization +from .custom_config import ConvertCustomConfig, PrepareCustomConfig +from .graph_module import _is_observed_module, _is_observed_standalone_module +from .lower_to_fbgemm import lower_to_fbgemm +from .qconfig_mapping_utils import ( + _compare_prepare_convert_qconfig_mappings, + _generate_node_name_to_qconfig, + _is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .utils import ( + _get_module, + _is_custom_module_lstm, + _is_custom_module_mha, + assert_and_get_unique_device, + collect_producer_nodes, + create_getattr_from_value, + get_custom_module_class_keys, + graph_module_from_producer_nodes, + node_arg_is_weight, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = [ + "convert", + "convert_custom_module", + "convert_standalone_module", + "convert_weighted_module", +] + +SUPPORTED_QDTYPES = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.uint8, + torch.int8, + torch.uint16, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_QSCHEME_TO_CHOOSE_QPARAMS_OP = { + torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +} + + +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], + model_device: torch.device | None = None, +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + graph = model.graph + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + if hasattr(activation_post_process, "convert"): + activation_post_process.convert(model, node) + return + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + def add_dequantize_op_kwargs(dequantize_op, input_node): + dequantize_op_kwargs = {} + if "val" in input_node.meta: + dq_out_dtype = input_node.meta["val"].dtype + if dq_out_dtype != torch.float32: + dequantize_op_kwargs = {"out_dtype": dq_out_dtype} + return dequantize_op_kwargs + + if dtype in SUPPORTED_QDTYPES and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Callable | None = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and ( + not isinstance(value_or_node, (float, int)) + ): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + if dtype_ not in [torch.uint8, torch.int8]: + raise AssertionError( + "only uint8 and int8 are supported in reference flow for dynamic quantization right now" + ) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] + eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_eps_": eps, + "_dtype_": dtype_, + } + + choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + list(qparams.values()) + choose_qparams_node = graph.create_node( + "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 0), {} + ) + zero_point_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 1), {} + ) + # we have quant_min, quant_max and dtype, all should be stored + # as literals + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype, + } + + # 3. replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if NUMERIC_DEBUG_HANDLE_KEY in node.meta: + dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + NUMERIC_DEBUG_HANDLE_KEY + ] + graph.erase_node(node) + elif dtype == torch.float16: + # Insert to_fp16 -> to_fp32 node + dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + with graph.inserting_before(node): + input_node = node.args[0] + convert_fp16_node = graph.create_node( + "call_function", dtype_convert_op, (input_node, torch.float16), {} + ) + convert_fp32_node = graph.create_node( + "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {} + ) + node.replace_all_uses_with(convert_fp32_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +def _replace_observer_with_quantize_dequantize_node( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], + model_device: torch.device | None = None, +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + """ + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) + graph = model.graph + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op: Callable | None = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_dtype_": dtype, + } + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for value in qparams.values(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for value in qparams.values(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node( + node: Node, graph: Graph +) -> None: + call_custom_module_node = node.args[0] + if not isinstance(call_custom_module_node, Node): + raise AssertionError( + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + ) + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + _insert_dequantize_node(call_custom_module_node, graph) + + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) + or is_dynamic # type: ignore[return-value] + or dtype == torch.float16 + ) + + +def _has_none_qconfig( + node: Argument, node_name_to_qconfig: dict[str, QConfigAny] +) -> bool: + """Check if a node has a qconfig of None, i.e. user requested to not quantize + the node + """ + return ( + isinstance(node, Node) + and node.name in node_name_to_qconfig + and node_name_to_qconfig[node.name] is None + ) + + +def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: + """Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. + """ + for node in observed.graph.nodes: + if node.op != "call_function": + continue + for node_arg in node.args: + # node_arg is weight + if node_arg and node_arg_is_weight(node, node_arg): + weight_observer_nodes = collect_producer_nodes(node_arg) + if weight_observer_nodes is None: + continue + weight_observer_module = graph_module_from_producer_nodes( + observed, weight_observer_nodes + ) + # run the weight observer + weight_observer_module() + + +def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: + """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, + we'll recursively remove the dequantize Node + """ + if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize": + quantize_node = arg.args[0] + # we only replace the specific use since dequantize could be used by other nodes + # as well + node.replace_input_with(arg, quantize_node) + elif isinstance(arg, (list, tuple)): + for arg_element in arg: + _maybe_recursive_remove_dequantize(arg_element, node, graph) + elif isinstance(arg, dict): + for arg_element in arg.values(): + _maybe_recursive_remove_dequantize(arg_element, node, graph) + else: + warnings.warn( + f"Unsupported node type in recursive remove dequantize: {type(arg)}", + stacklevel=2, + ) + + +def _get_module_path_and_prefix( + obs_node: Node, + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> tuple[str, str]: + """Given and observer node, get the `Scope` or the fully qualified name for + the submodule containing the observed node, also return a prefix of "_input" + when the observed node is an input of a F.linear op, and not the output of another + quantized op. + TODO: this logic is hacky, we should think about how to remove it or make it more + general + """ + observed_node = obs_node.args[0] + # an observer can be inserted for both input of the next operator or output of the previous + # operator (they can be the same) + # this flag identifies if the observer is inserted only because the observed node is + # the input of the next operator + if not isinstance(observed_node, Node): + raise AssertionError( + f"Expecting observed node to be a Node, but got {observed_node}" + ) + is_input_observer_only = ( + node_name_to_qconfig[observed_node.name] is None + if observed_node.name in node_name_to_qconfig + else None + ) + if is_input_observer_only: + # if the quantize function is at the input of op, then we find the first user of the observer_node + # to get the path. If a linear call_function is in the user list, we return the first instance + # of linear node to get the FQN. + users = list(obs_node.users) + first_linear_use_or_first_use = users[0] if users else None + linear_node = None + for n in users: + if n.op == "call_function" and n.target is torch.nn.functional.linear: + linear_node = n + break + if linear_node: + first_linear_use_or_first_use = linear_node + prefix = "_input" + else: + # if the quantize function is at the output of the op, we use the observer input node to get the path + first_linear_use_or_first_use = observed_node + prefix = "" + + if ( + first_linear_use_or_first_use + and first_linear_use_or_first_use.name in node_name_to_scope + ): + module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] + else: + # TODO: it's not used, so actually we can skip quantization + # but this requires changing return type of quantize_node + # we can fix it later if needed + module_path = "" + return module_path, prefix + + +def _insert_dequantize_node(node: Node, graph: Graph) -> None: + """Inserts dequantize node for `node` in `graph`""" + with graph.inserting_after(node): + dequantize_node = graph.call_method("dequantize", (node,)) + for user_node in dict(node.users): + if user_node is not dequantize_node: + user_node.replace_input_with(node, dequantize_node) + + +def _maybe_get_observer_for_node( + node: Node, modules: dict[str, torch.nn.Module] +) -> torch.nn.Module | None: + """ + If the node is observed, return the observer + instance. Otherwise, return None. + """ + for maybe_obs_node in node.users: + if maybe_obs_node.op == "call_module": + maybe_obs = modules[str(maybe_obs_node.target)] + if _is_activation_post_process(maybe_obs): + return maybe_obs + return None + + +def convert_standalone_module( + node: Node, + modules: dict[str, torch.nn.Module], + model: torch.fx.GraphModule, + is_reference: bool, + backend_config: BackendConfig | None, +) -> None: + """Converts a observed standalone module to a quantized standalone module by calling + the fx convert api, currently using the same `is_reference` flag as parent, but we may + changing this behavior in the future (e.g. separating quantization and lowering for + standalone module as well) + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - model: original model + - is_reference: a flag from parent provided by user to decide if we want to + produce a reference model or a fbgemm/qnnpack model + - backend_config: backend configuration of the target backend of quantization + """ + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # We know that observed standalone module is a GraphModule since + # it's produced by us + observed_standalone_module: GraphModule = modules[str(node.target)] # type: ignore[assignment] + sm_input_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_input_quantized_idxs + # remove the dequantize nodes for inputs + args = list(node.args) + for idx in range(len(args)): + if idx in sm_input_quantized_idxs: + arg = args[idx] + if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr] + quantize_node = arg.args[0] # type: ignore[union-attr] + node.replace_input_with(arg, quantize_node) + if len(arg.users) == 0: # type: ignore[union-attr] + model.graph.erase_node(arg) + # add dequantize node for output + sm_output_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_output_quantized_idxs + if len(sm_output_quantized_idxs) > 0: + if sm_output_quantized_idxs[0] != 0: + raise AssertionError( + "Currently only quantized output idxs = [0] is supported" + ) + + # if it's non-empty, then it means the output is kept in quantized form + # we'll just add a dequantize node after this node + _insert_dequantize_node(node, model.graph) + + # TODO: allow convert_custom_config to override backend_config + # for standalone module + quantized_standalone_module = convert_fn( + observed_standalone_module, backend_config=backend_config + ) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(modules[parent_name], name, quantized_standalone_module) + modules[str(node.target)] = quantized_standalone_module + + +def convert_weighted_module( + node: Node, + modules: dict[str, torch.nn.Module], + observed_node_names: set[str], + node_name_to_qconfig: dict[str, QConfigAny], + backend_config: BackendConfig, + is_decomposed: bool = False, + is_reference: bool = False, + model_device: torch.device | None = None, +) -> None: + """Convert a weighted module to reference quantized module in the model + If the QConfig of a QAT module is not set, the module will still be converted to + a float module. + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - observed_node_names: names for the set of observed fx node, we can skip + this conversion if the node is not observed + """ + original_module = modules[str(node.target)] + qconfig: QConfigAny = original_module.qconfig # type: ignore[assignment] + weight_post_process = None + qat_module_classes = get_qat_module_classes(backend_config) + + if isinstance(original_module, qat_module_classes): + # Converting qat module to a float module, we need to attach + # weight fake_quant to the module, weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + weight_post_process = original_module.weight_fake_quant + original_module = original_module.to_float() # type: ignore[operator] + # change qat module to float module + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, original_module) + + is_observed = node.name in observed_node_names + # If a qconfig is not defined for this node, then skip converting to a reference module + if ( + qconfig is None + or _has_none_qconfig(node, node_name_to_qconfig) + or not is_observed + ): + return + + # skip converting to reference quantized module if the qconfig is not supported + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) + if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): + return + + # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized + is_weight_quantized = weight_is_quantized(qconfig) + + # the condition for swapping the module to reference quantized module is: + # weights need to be quantized + if not is_weight_quantized: + return + + fused_module = None + float_module = original_module + # extract the individual float_module and fused module + if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule): + fused_module = float_module + float_module = fused_module[0] # type: ignore[index] + + # TODO: move this to the reference quantized module + # weight_qparams or weight_qparams dict + wq_or_wq_dict = {"is_decomposed": is_decomposed} + if isinstance(float_module, torch.nn.RNNCellBase): + weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_ih(float_module.weight_ih) + weight_post_process_hh(float_module.weight_hh) + weight_qparams_ih = get_qparam_dict(weight_post_process_ih) + weight_qparams_hh = get_qparam_dict(weight_post_process_hh) + wq_or_wq_dict.update( + { + "weight_ih": weight_qparams_ih, + "weight_hh": weight_qparams_hh, + } + ) + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): + # format for wq_or_wq_dict (flattened attributes): + # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} + for wn in float_module._flat_weights_names: + if hasattr(float_module, wn) and wn.startswith("weight"): + weight = getattr(float_module, wn) + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if weight_post_process.dtype == torch.qint8: # type: ignore[union-attr] + weight_post_process(weight) # type: ignore[operator, misc] + wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) + else: + # weight_post_process is None means the original module is not a QAT module + # we need to get weight_post_process from qconfig in this case + is_ptq = weight_post_process is None + if is_ptq: + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if model_device is not None: + device = model_device + else: + device = assert_and_get_unique_device(float_module) + if device: + weight_post_process.to(device) + + # Call weight observer/fake_quant at least once to ensure the scales and zero points + # have the right shapes. Note: there are two cases where we don't have to do this: + # + # (1) QAT: The model's forward method already calls the weight observer/fake_quant, + # and this typically happens during training, so we don't need to do it here. + # + # (2) Non-reference (lowered) case: The quantized module's from_float method already + # calls the weight observer/fake_quant, so we don't have to do it here. + # + # Currently we ignore both cases and call the weight observer/fake_quant here + # regardless, which is technically incorrect. For (1), this is mainly to preserve BC + # in test code, which may not always train before convert. In the future, we should + # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941. + # + # For PT2, however, we don't need to preserve BC here, so we can skip this hack + # for QAT. We identify this case as (is_decomposed + is_reference + is_qat). + # Note that we still need it for PTQ in the PT2 flow since the model's forward + # method doesn't call the weight observer. + is_qat = not is_ptq + if not (is_decomposed and is_reference and is_qat): + weight_post_process(float_module.weight) # type: ignore[operator] + + wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) + + # We use the same reference module for all modes of quantization: static, dynamic, weight_only + # root_module_to_quantized_reference_module: module mapping from root (floating point) module class + # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + ref_qmodule_cls = root_module_to_quantized_reference_module.get( + type_before_parametrizations(float_module), None + ) + if ref_qmodule_cls is None: + raise AssertionError( + f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ) + ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] + if fused_module is not None: + fused_module[0] = ref_qmodule # type: ignore[operator] + else: + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, ref_qmodule) + + +def _remove_previous_dequantize_in_custom_module( + node: Node, prev_node: Node, graph: Graph +) -> None: + """ + Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: + + Before: quantize - dequantize - custom_module + After: quantize - custom_module + \\ - dequantize + """ + # expecting the input node for a custom module node to be a Node + if not isinstance(prev_node, Node): + raise AssertionError( + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + ) + if prev_node.op == "call_method" and prev_node.target == "dequantize": + node.replace_input_with(prev_node, prev_node.args[0]) + # Remove the dequantize node if it doesn't have other users + if len(prev_node.users) == 0: + graph.erase_node(prev_node) + + +def convert_custom_module( + node: Node, + graph: Graph, + modules: dict[str, torch.nn.Module], + custom_module_class_mapping: dict[QuantType, dict[type, type]], + statically_quantized_custom_module_nodes: set[Node], +) -> None: + """Converts an observed custom module to a quantized custom module based on + `custom_module_class_mapping` + For static quantization, we'll also remove the previous `dequantize` node and + attach the observer node for output to the module, the observer for the node + will be converted to a dequantize node instead of quantize-dequantize pairs + later in the graph. In the end we would have a quantized custom module that + has the same interface as a default quantized module in nn.quantized namespace, + i.e. quantized input and quantized output. + + Args: + - node: The call_module node of the observed standalone module + - graph: The graph containing the node + - modules: named_module of original model + - custom_module_class_mapping: mapping from observed custom module class to + quantized custom module class, used to swap custom modules + - statically_quantized_custom_module_nodes: we'll add the custom module node + if we find it is statically quantized, this will be used later when converting + observers to quant/dequant node pairs, if the observed node is a statically + quantized custom module nodes, we'll convert the observer to a dequantize node, + this is to keep the interface the same as the default quantized module. + TODO: maybe we want to redesign this part to align with reference model design + as well, but there has been some discussions around the interface, so we can do + it later. + """ + observed_custom_module = modules[str(node.target)] + qconfig = observed_custom_module.qconfig + if activation_is_statically_quantized(qconfig): + statically_quantized_custom_module_nodes.add(node) + if _is_custom_module_lstm(node, modules): + # The inputs are tuples in the form (input, (hidden0, hidden1)) + # Ensure all three input nodes are quantized + if not ( + len(node.args) == 2 + and isinstance(node.args[1], tuple) + and len(node.args[1]) == 2 + ): + raise AssertionError( + "Expected LSTM custom module inputs to be (input, (hidden0, hidden1))" + ) + (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] + if not isinstance(inputs, Node): + raise AssertionError("Expected inputs to be a Node") + if not isinstance(hidden0, Node): + raise AssertionError("Expected hidden0 to be a Node") + if not isinstance(hidden1, Node): + raise AssertionError("Expected hidden1 to be a Node") + _remove_previous_dequantize_in_custom_module(node, inputs, graph) + _remove_previous_dequantize_in_custom_module(node, hidden0, graph) + _remove_previous_dequantize_in_custom_module(node, hidden1, graph) + elif _is_custom_module_mha(node, modules): + # Inputs are in the form (query, key, value) + # TODO: This is the first step in enabling the full fx custom module + # quantization path for MultiheadAttention, and only covers the inputs + # to the module. + # Additional handling is yet to be implemented for the outputs, similar + # to LSTM custom module + if len(node.args) != 3: + raise AssertionError( + "Expected MHA custom module inputs to be (query, key, value)" + ) + query, key, value = node.args + if not isinstance(query, Node): + raise AssertionError("Expected query to be a Node") + if not isinstance(key, Node): + raise AssertionError("Expected key to be a Node") + if not isinstance(value, Node): + raise AssertionError("Expected value to be a Node") + _remove_previous_dequantize_in_custom_module(node, query, graph) + _remove_previous_dequantize_in_custom_module(node, key, graph) + _remove_previous_dequantize_in_custom_module(node, value, graph) + else: + # remove the previous dequant node to ensure the inputs are quantized + arg = node.args[0] + if not isinstance(arg, Node): + raise AssertionError("Expected arg to be a Node") + _remove_previous_dequantize_in_custom_module(node, arg, graph) + # absorb the following observer into the module conversion + activation_post_process = _maybe_get_observer_for_node(node, modules) + if activation_post_process is None: + raise AssertionError( + "Expected activation_post_process to be present for observed custom module" + ) + observed_custom_module.activation_post_process = activation_post_process + + # swap the observed custom module to quantized custom module + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig + ) + quantized_custom_module = quantized_custom_module_class.from_observed( + observed_custom_module + ) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, quantized_custom_module) + + +def convert( + model: GraphModule, + is_reference: bool = False, + convert_custom_config: ConvertCustomConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, + _remove_qconfig_flag: bool = True, + qconfig_mapping: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """ + We will convert an observed model (a module with observer calls) to a reference + quantized model, the rule is simple: + 1. for each observer module call in the graph, we'll convert it to calls to + quantize and dequantize functions based on the observer instance + 2. for weighted operations like linear/conv, we need to convert them to reference + quantized module, this requires us to know whether the dtype configured for the + weight is supported in the backend, this is done in prepare step and the result + is stored in observed_node_names, we can decide whether we need to swap the + module based on this set + + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) + + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to convert is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = ( + QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None + ) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)): + raise AssertionError("qconfig_mapping must be None or a QConfigMapping") + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if backend_config is None: + backend_config = get_native_backend_config() + + if not _is_observed_module(model): + raise AssertionError("incoming model must be produced by prepare_fx") + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + node_name_to_scope: dict[str, tuple[str, type]] = ( + observed_graph_module_attrs.node_name_to_scope + ) + prepare_custom_config: PrepareCustomConfig = ( + observed_graph_module_attrs.prepare_custom_config + ) + observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names + node_name_to_qconfig: dict[str, QConfigAny] = ( + observed_graph_module_attrs.node_name_to_qconfig + ) # type: ignore[assignment] + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + # We use remove_duplicate=False here because torch.cat uses + # the same activation_post_process module instance but different names + modules = dict(model.named_modules(remove_duplicate=False)) + + # TODO refactor this code once we update the prepare logic to have additional information on + # which graph nodes have been observed and share that with convert to decide which observers to ignore. + if qconfig_mapping: + prepare_qconfig_mapping: QConfigMapping = ( + observed_graph_module_attrs.qconfig_mapping + ) # type: ignore[assignment] + modules_copy = copy.deepcopy(modules) + + if observed_graph_module_attrs.is_qat: + _update_qconfig_for_qat(qconfig_mapping, backend_config) + _update_qconfig_for_fusion(model, qconfig_mapping) + + _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping, qconfig_mapping + ) # type: ignore[arg-type] + convert_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope + ) + # check the convert_node_name_to_qconfig generated and ensure that + # all the values either match what was set in prepare node_name_to_qconfig + # or are set to None in the convert_node_name_to_qconfig. + for k, v in node_name_to_qconfig.items(): + if k not in convert_node_name_to_qconfig: + raise AssertionError( + f"Expected key {k} in convert node_name_to_qconfig" + ) + if convert_node_name_to_qconfig[k] is not None: + if not qconfig_equals(v, convert_node_name_to_qconfig[k]): + raise AssertionError( + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + ) + node_name_to_qconfig = convert_node_name_to_qconfig + + custom_module_classes = get_custom_module_class_keys( + convert_custom_config.observed_to_quantized_mapping + ) + custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping + + if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: + # If we want to do equalization then do the following: + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + _run_weight_observers(model, backend_config) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + # convert tuples so that it can work with isinstance(module, tuple_of_classes) + root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) + qat_module_classes = get_qat_module_classes(backend_config) + fused_module_classes = get_fused_module_classes(backend_config) + statically_quantized_custom_module_nodes: set[Node] = set() + model_device = assert_and_get_unique_device(model) + + for node in list(model.graph.nodes): + if node.op == "placeholder": + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + # Inputs are assumed to be quantized if the user specified the + # input_quantized_idxs override. + # we need to dequantize the inputs since all operators took + # floating point inputs in reference quantized models + _insert_dequantize_node(node, model.graph) + elif node.op == "output": + # If the argument is empty we don't need to do anything + if len(output_quantized_idxs) == 0: + continue + # Result are kept quantized if the user specified the + # output_quantized_idxs override. + # Remove the dequantize operator for the node in the end if any + return_node = node + output = node.args[0] + # outputs can be Node, list, tuple, dict, other cases are not supported yet + if isinstance(output, (list, tuple)): + for idx in output_quantized_idxs: + _maybe_recursive_remove_dequantize( + output[idx], return_node, model.graph + ) + elif isinstance(output, (Node, dict)): + # we treat dict as a single argument currently, but it can be extended + # to support {"key": dtype} after we change output_quantized_idxs to + # dict + if 0 in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output, return_node, model.graph) + else: + warnings.warn( + f"Unsupported node type for output_quantized_idxs: {type(output)}", + stacklevel=2, + ) + elif node.op == "call_module": + mod = _get_module(node, modules) + if mod is None: + raise AssertionError( + "Expected module for call_module node to be present in modules mapping" + ) + if _is_activation_post_process(mod): + observed_node = node.args[0] + if observed_node in statically_quantized_custom_module_nodes: + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + else: + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + model_device, + ) + else: + _replace_observer_with_quantize_dequantize_node( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + model_device, + ) + elif isinstance(mod, DeQuantStub): + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + elif _is_observed_standalone_module(mod): + convert_standalone_module( + node, modules, model, is_reference, backend_config + ) + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(mod) in set(root_module_classes).union( + qat_module_classes + ).union(fused_module_classes): + # extra check for fused module classes to make sure they are fused module classes + # of target modules + if ( + type_before_parametrizations(mod) in fused_module_classes + and type_before_parametrizations(mod[0]) not in root_module_classes + ): # type: ignore[index] + continue + convert_weighted_module( + node, + modules, + observed_node_names, + node_name_to_qconfig, + backend_config, + is_decomposed, + is_reference, + model_device, + ) + elif type_before_parametrizations(mod) in custom_module_classes: + convert_custom_module( + node, + model.graph, + modules, + custom_module_class_mapping, + statically_quantized_custom_module_nodes, + ) + + # remove deadcode after converting observers to quant/dequant ops + model.graph.eliminate_dead_code() + model = GraphModule(model, model.graph) + + # TODO: maybe move this to quantize_fx.py + if not is_reference: + model = lower_to_fbgemm( + model, node_name_to_qconfig, node_name_to_scope, keep_original_weights + ) + + # TODO: this looks hacky, we want to check why we need this and see if we can + # remove this + # removes qconfig and activation_post_process modules + if _remove_qconfig_flag: + _remove_qconfig(model) + model.delete_all_unused_submodules() + model.meta.pop("_observed_graph_module_attrs", None) + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e749de94bd5c3d1eb0c34a14cfcf38d441aedbff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py @@ -0,0 +1,521 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.quant_type import ( + _get_quant_type_to_str, + _quant_type_from_str, + QuantType, +) + + +__all__ = [ + "ConvertCustomConfig", + "FuseCustomConfig", + "PrepareCustomConfig", + "StandaloneModuleConfigEntry", +] + + +# TODO: replace all usages with these constants +STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name" +STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class" +FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class" +OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class" +NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name" +NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class" +INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs" +OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs" +PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes" + + +@dataclass +class StandaloneModuleConfigEntry: + # qconfig_mapping for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_mapping + qconfig_mapping: QConfigMapping | None + example_inputs: tuple[Any, ...] + prepare_custom_config: PrepareCustomConfig | None + backend_config: BackendConfig | None + + +class PrepareCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and + :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`. + + Example usage:: + + prepare_custom_config = PrepareCustomConfig() \ + .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \ + .set_non_traceable_module_names(["module2", "module3"]) \ + .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \ + .set_input_quantized_indexes([0]) \ + .set_output_quantized_indexes([0]) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.standalone_module_names: dict[str, StandaloneModuleConfigEntry] = {} + self.standalone_module_classes: dict[type, StandaloneModuleConfigEntry] = {} + self.float_to_observed_mapping: dict[QuantType, dict[type, type]] = {} + self.non_traceable_module_names: list[str] = [] + self.non_traceable_module_classes: list[type] = [] + self.input_quantized_indexes: list[int] = [] + self.output_quantized_indexes: list[int] = [] + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"PrepareCustomConfig({dict_nonempty})" + + def set_standalone_module_name( + self, + module_name: str, + qconfig_mapping: QConfigMapping | None, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | None, + backend_config: BackendConfig | None, + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_name``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_names[module_name] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_standalone_module_class( + self, + module_class: type, + qconfig_mapping: QConfigMapping | None, + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | None, + backend_config: BackendConfig | None, + ) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_class``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry( + qconfig_mapping, example_inputs, prepare_custom_config, backend_config + ) + return self + + def set_float_to_observed_mapping( + self, + float_class: type, + observed_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> PrepareCustomConfig: + """ + Set the mapping from a custom float module class to a custom observed module class. + + The observed module class must have a ``from_float`` class method that converts the float module class + to the observed module class. This is currently only supported for static quantization. + """ + if quant_type != QuantType.STATIC: + raise ValueError( + "set_float_to_observed_mapping is currently only supported for static quantization" + ) + if quant_type not in self.float_to_observed_mapping: + self.float_to_observed_mapping[quant_type] = {} + self.float_to_observed_mapping[quant_type][float_class] = observed_class + return self + + def set_non_traceable_module_names( + self, module_names: list[str] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by name. + """ + self.non_traceable_module_names = module_names + return self + + def set_non_traceable_module_classes( + self, module_classes: list[type] + ) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by class. + """ + self.non_traceable_module_classes = module_classes + return self + + def set_input_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the inputs of the graph that should be quantized. + Inputs are otherwise assumed to be in fp32 by default instead. + """ + self.input_quantized_indexes = indexes + return self + + def set_output_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig: + """ + Set the indexes of the outputs of the graph that should be quantized. + Outputs are otherwise assumed to be in fp32 by default instead. + """ + self.output_quantized_indexes = indexes + return self + + def set_preserved_attributes(self, attributes: list[str]) -> PrepareCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, prepare_custom_config_dict: dict[str, Any] + ) -> PrepareCustomConfig: + """ + Create a ``PrepareCustomConfig`` from a dictionary with the following items: + + "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "float_to_observed_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from float module classes to observed module classes, e.g. + {"static": {FloatCustomModule: ObservedCustomModule}} + + "non_traceable_module_name": a list of modules names that are not symbolically traceable + "non_traceable_module_class": a list of module classes that are not symbolically traceable + "input_quantized_idxs": a list of indexes of graph inputs that should be quantized + "output_quantized_idxs": a list of indexes of graph outputs that should be quantized + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + + def _get_qconfig_mapping(obj: Any, dict_key: str) -> QConfigMapping | None: + """ + Convert the given object into a QConfigMapping if possible, else throw an exception. + """ + if isinstance(obj, QConfigMapping) or obj is None: + return obj + if isinstance(obj, dict): + return QConfigMapping.from_dict(obj) + raise ValueError( + f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_prepare_custom_config( + obj: Any, dict_key: str + ) -> PrepareCustomConfig | None: + """ + Convert the given object into a PrepareCustomConfig if possible, else throw an exception. + """ + if isinstance(obj, PrepareCustomConfig) or obj is None: + return obj + if isinstance(obj, dict): + return PrepareCustomConfig.from_dict(obj) + raise ValueError( + f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + def _get_backend_config(obj: Any, dict_key: str) -> BackendConfig | None: + """ + Convert the given object into a BackendConfig if possible, else throw an exception. + """ + if isinstance(obj, BackendConfig) or obj is None: + return obj + if isinstance(obj, dict): + return BackendConfig.from_dict(obj) + raise ValueError( + f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'" + ) + + conf = cls() + for ( + module_name, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY + ) + conf.set_standalone_module_name( + module_name, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for ( + module_class, + qconfig_dict, + example_inputs, + _prepare_custom_config_dict, + backend_config_dict, + ) in prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping( + qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + prepare_custom_config = _get_prepare_custom_config( + _prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + backend_config = _get_backend_config( + backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY + ) + conf.set_standalone_module_class( + module_class, + qconfig_mapping, + example_inputs, + prepare_custom_config, + backend_config, + ) + for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get( + FLOAT_TO_OBSERVED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for float_class, observed_class in custom_module_mapping.items(): + conf.set_float_to_observed_mapping( + float_class, observed_class, quant_type + ) + conf.set_non_traceable_module_names( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, []) + ) + conf.set_non_traceable_module_classes( + prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, []) + ) + conf.set_input_quantized_indexes( + prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_output_quantized_indexes( + prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, []) + ) + conf.set_preserved_attributes( + prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``PrepareCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. + """ + + def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): + qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None + prepare_custom_config_dict = ( + e.prepare_custom_config.to_dict() if e.prepare_custom_config else None + ) + return ( + key, + qconfig_dict, + e.example_inputs, + prepare_custom_config_dict, + e.backend_config, + ) + + d: dict[str, Any] = {} + for module_name, sm_config_entry in self.standalone_module_names.items(): + if STANDALONE_MODULE_NAME_DICT_KEY not in d: + d[STANDALONE_MODULE_NAME_DICT_KEY] = [] + d[STANDALONE_MODULE_NAME_DICT_KEY].append( + _make_tuple(module_name, sm_config_entry) + ) + for module_class, sm_config_entry in self.standalone_module_classes.items(): + if STANDALONE_MODULE_CLASS_DICT_KEY not in d: + d[STANDALONE_MODULE_CLASS_DICT_KEY] = [] + d[STANDALONE_MODULE_CLASS_DICT_KEY].append( + _make_tuple(module_class, sm_config_entry) + ) + for ( + quant_type, + float_to_observed_mapping, + ) in self.float_to_observed_mapping.items(): + if FLOAT_TO_OBSERVED_DICT_KEY not in d: + d[FLOAT_TO_OBSERVED_DICT_KEY] = {} + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + float_to_observed_mapping + ) + if len(self.non_traceable_module_names) > 0: + d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names + if len(self.non_traceable_module_classes) > 0: + d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes + if len(self.input_quantized_indexes) > 0: + d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes + if len(self.output_quantized_indexes) > 0: + d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class ConvertCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. + + Example usage:: + + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self) -> None: + self.observed_to_quantized_mapping: dict[QuantType, dict[type, type]] = {} + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"ConvertCustomConfig({dict_nonempty})" + + def set_observed_to_quantized_mapping( + self, + observed_class: type, + quantized_class: type, + quant_type: QuantType = QuantType.STATIC, + ) -> ConvertCustomConfig: + """ + Set the mapping from a custom observed module class to a custom quantized module class. + + The quantized module class must have a ``from_observed`` class method that converts the observed module class + to the quantized module class. + """ + if quant_type not in self.observed_to_quantized_mapping: + self.observed_to_quantized_mapping[quant_type] = {} + self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class + return self + + def set_preserved_attributes(self, attributes: list[str]) -> ConvertCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict( + cls, convert_custom_config_dict: dict[str, Any] + ) -> ConvertCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from observed module classes to quantized module classes, e.g.:: + { + "static": {FloatCustomModule: ObservedCustomModule}, + "dynamic": {FloatCustomModule: ObservedCustomModule}, + "weight_only": {FloatCustomModule: ObservedCustomModule} + } + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + for quant_type_name, custom_module_mapping in convert_custom_config_dict.get( + OBSERVED_TO_QUANTIZED_DICT_KEY, {} + ).items(): + quant_type = _quant_type_from_str(quant_type_name) + for observed_class, quantized_class in custom_module_mapping.items(): + conf.set_observed_to_quantized_mapping( + observed_class, quantized_class, quant_type + ) + conf.set_preserved_attributes( + convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``ConvertCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. + """ + d: dict[str, Any] = {} + for ( + quant_type, + observed_to_quantized_mapping, + ) in self.observed_to_quantized_mapping.items(): + if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: + d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = ( + observed_to_quantized_mapping + ) + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class FuseCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`. + + Example usage:: + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + ["attr1", "attr2"] + ) + """ + + def __init__(self) -> None: + self.preserved_attributes: list[str] = [] + + def __repr__(self): + dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0} + return f"FuseCustomConfig({dict_nonempty})" + + def set_preserved_attributes(self, attributes: list[str]) -> FuseCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, fuse_custom_config_dict: dict[str, Any]) -> FuseCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + conf.set_preserved_attributes( + fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []) + ) + return conf + + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``FuseCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. + """ + d: dict[str, Any] = {} + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4ee15779a180ea88c7dda47c7e6a45da092714 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py @@ -0,0 +1,195 @@ +# mypy: allow-untyped-defs +import warnings +from collections.abc import Callable +from typing import Any + +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fuser_method_mapping, + get_fusion_pattern_to_extra_inputs_getter, + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.utils import NodePattern, Pattern +from torch.fx import GraphModule, map_arg, Node +from torch.fx.graph import Graph + +from .custom_config import FuseCustomConfig +from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler +from .match_utils import _is_match, MatchAllNode +from .pattern_utils import _sorted_patterns_dict + + +__all__ = [ + "fuse", + # TODO: We should make this private in the future + # This is currently needed for test_public_bindings for some reason + "FuseHandler", +] + + +def fuse( + model: GraphModule, + is_qat: bool, + fuse_custom_config: FuseCustomConfig | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, +) -> GraphModule: + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + named_modules = dict(model.named_modules()) + + if backend_config is None: + backend_config = get_native_backend_config() + + fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict( + _get_fusion_pattern_to_fuse_handler_cls(backend_config) + ) + fuser_method_mapping = get_fuser_method_mapping(backend_config) + fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter( + backend_config + ) + fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter( + backend_config + ) + + # find fusion + fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls) + # TODO: change this to inplace changes to graph, since we no longer construct + # new GraphModule anymore + fused_graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + def default_root_node_getter(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + + for node in model.graph.nodes: + ( + maybe_last_node, + pattern, + matched_node_pattern, + obj, + node_to_subpattern, + ) = fusion_pairs.get(node.name, (None, None, None, None, None)) + # get the corresponding subpattern for the current node + if node_to_subpattern is not None: + node_subpattern = node_to_subpattern.get(node, None) + else: + node_subpattern = None + if maybe_last_node is node: + if obj is None: + raise AssertionError( + "fuse handler object must not be None for matched root node" + ) + root_node_getter = fusion_pattern_to_root_node_getter.get( + pattern, default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) # type: ignore[index] + extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get( + pattern, None + ) + extra_inputs = [] + if extra_inputs_getter is not None: + extra_inputs = extra_inputs_getter(matched_node_pattern) + # TODO: add validation that root_node is a module and has the same type + # as the root_module in the configuration + env[node.name] = obj.fuse( + load_arg, + named_modules, + fused_graph, + root_node, + extra_inputs, + matched_node_pattern, # type: ignore[arg-type] + fuse_custom_config, + fuser_method_mapping, + is_qat, + ) + elif maybe_last_node is None or node_subpattern is MatchAllNode: + env[node.name] = fused_graph.node_copy(node, load_arg) + # node matched in patterns and is not root is removed here + + model = GraphModule(model, fused_graph) + return model + + +def _find_matches( + root: GraphModule, + graph: Graph, + pattern_to_fuse_handler_cls: dict[Pattern, Callable], +) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]: + modules = dict(root.named_modules()) + # node name -> (root_node, match_value) + match_map: dict[ + str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]] + ] = {} + # a map from node to the matched subpattern + node_to_subpattern: dict[Node, Any] = {} + + # TODO: dedup with quantization matching function in match_utils.py + def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern): + if isinstance(pattern, tuple): + s, *args = pattern + current_node_pattern: list[Node] = [] + apply_match(s, node, match, current_node_pattern, node_to_subpattern) + for subpattern, arg in zip(args, node.args): + apply_match( + subpattern, arg, match, current_node_pattern, node_to_subpattern + ) + matched_node_pattern.append(tuple(current_node_pattern)) + else: + # the first pattern matches will take precedence + if node.name not in match_map: + matched_node_pattern.append(node) + # MatchAllNode here is actually MatchAllInputNode which should not + # be added to match_map + if pattern is not MatchAllNode: + node_to_subpattern[node] = pattern + root_node, pattern, handler = match + match_map[node.name] = ( + root_node, + pattern, + matched_node_pattern, + handler, + node_to_subpattern, + ) + + for node in reversed(graph.nodes): + if node.name not in match_map: + for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items(): + matched_node_pattern: list[Node] = [] + if _is_match(modules, node, pattern): + apply_match( + pattern, + node, + (node, pattern, fuse_handler_cls(node)), + matched_node_pattern, + node_to_subpattern, + ) + break + + return match_map diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..76fe84c2c3ad5fd88303d5f04e83c3dccfd24e5a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new +from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .custom_config import FuseCustomConfig +from .match_utils import MatchAllNode + + +__all__ = [ + "DefaultFuseHandler", + "FuseHandler", +] + + +# ---------------------------- +# Fusion Pattern Registrations +# ---------------------------- + + +# Base Pattern Handler +class FuseHandler(ABC): + """Base handler class for the fusion patterns""" + + @abstractmethod + def __init__(self, node: Node): + pass + + @abstractmethod + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, torch.nn.Sequential | Callable], + is_qat: bool, + ) -> Node: + pass + + +class DefaultFuseHandler(FuseHandler): + def __init__(self, node: Node): # pylint: disable=useless-parent-delegation + super().__init__(node) # type:ignore[safe-super] + + def fuse( + self, + load_arg: Callable, + named_modules: dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: list[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: dict[Pattern, torch.nn.Sequential | Callable], + is_qat: bool, + ) -> Node: + if root_node.op != "call_module": + raise AssertionError("Expecting module node to be a call_module Node") + root_module = named_modules[str(root_node.target)] + + def get_modules(pattern): + """Given a node pattern, extract the corresponding modules + e.g. input: (relu_node, (bn_node, conv_node)) + output: (relu_module, (bn_module, conv_module)) + """ + if isinstance(pattern, (tuple, list)): + n, *args = pattern + modules: list[torch.nn.Module] = [] + modules.append(get_modules(n)) + modules.extend(get_modules(a) for a in args) + return tuple(modules) + else: + n = pattern + if n.op == "call_module": + return named_modules[n.target] + elif n.op == "call_function" and n.target is torch.nn.functional.relu: + relu = torch.nn.ReLU() + relu.training = root_module.training + return relu + elif n.op == "call_function" or n.op == "call_method": + return n.target + else: + return MatchAllNode + + # since relu can be used multiple times, we'll need to create a relu module for each match + matched_modules = get_modules(matched_node_pattern) + + def get_matched_types(m): + if isinstance(m, tuple): + return tuple(map(get_matched_types, m)) + if isinstance(m, torch.nn.Module): + return type_before_parametrizations(m) + return m + + matched_module_types = get_matched_types(matched_modules) + module_parent_name, module_name = _parent_name(root_node.target) + fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping) + # TODO: change the signature for fuser_method to take matched module patterns + # as input + fused_module = fuser_method(is_qat, *matched_modules) + setattr(named_modules[module_parent_name], module_name, fused_module) + extra_args = [load_arg(input) for input in extra_inputs] + node = fused_graph.node_copy(root_node, load_arg) + args = list(node.args) + args.extend(extra_args) + node.args = tuple(args) + return node + + +def _get_fusion_pattern_to_fuse_handler_cls( + backend_config: BackendConfig, +) -> dict[Pattern, Callable]: + fusion_pattern_to_fuse_handlers: dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # TODO: is this logic right? + fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler + return fusion_pattern_to_fuse_handlers diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..87ec3179a68ee26a5b2199c3f7543fdfd73e2864 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py @@ -0,0 +1,205 @@ +# mypy: allow-untyped-defs +import copy +from typing import Any + +import torch +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__ = [ + "FusedGraphModule", + "ObservedGraphModule", + "ObservedStandaloneGraphModule", + "QuantizedGraphModule", +] + + +class FusedGraphModule(GraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return FusedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +class ObservedGraphModule(GraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = { + "_activation_post_process_map", + "_activation_post_process_indexes", + "_patterns", + "_node_name_to_qconfig", + "_prepare_custom_config", + "_equalization_node_name_to_qconfig", + "_node_name_to_scope", + "_qconfig_mapping", + "_is_qat", + "_observed_node_names", + }.union(preserved_attr_names) + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_module(module: Any) -> bool: + return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta + + +def _get_observed_graph_module_attr( + model: torch.nn.Module | GraphModule, attr_name: str +) -> Any: + if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta: # type: ignore[operator, index] + return getattr(model.meta["_observed_graph_module_attrs"], attr_name) # type: ignore[index] + return None + + +class ObservedStandaloneGraphModule(ObservedGraphModule): + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + preserved_attr_names = preserved_attr_names.union( + { + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs", + } + ) + super().__init__(root, graph, preserved_attr_names) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) + + +def _is_observed_standalone_module(module: Any) -> bool: + return ( + _is_observed_module(module) + and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module + ) + + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and isinstance( + getattr(self, attr_name), torch._C.ScriptObject + ): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + + +class QuantizedGraphModule(GraphModule): + """This class is created to make sure PackedParams + (e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict + so that we can serialize and deserialize quantized graph module with + torch.save(m.state_dict()) and m.load_state_dict(state_dict) + """ + + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: Graph, + preserved_attr_names: set[str], + ): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = { + attr: getattr(root, attr) + for attr in self.preserved_attr_names + if hasattr(root, attr) + } + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + self._register_state_dict_hook(_save_packed_weight) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance( + state_dict[attr_name], torch._C.ScriptObject + ): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return QuantizedGraphModule( + fake_mod, + copy.deepcopy(self.graph), + copy.deepcopy(self.preserved_attr_names), + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..73fd3e8741b6d6c26d5a352d25d4cf6986de4d9d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py @@ -0,0 +1,21 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_fbgemm"] + + +def lower_to_fbgemm( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + keep_original_weights: bool = False, +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to fbgemm + """ + return _lower_to_native_backend( + model, qconfig_map, node_name_to_scope, keep_original_weights + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fa3ecf3f5a3b2b5dc67d769853f8424bae7efb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py @@ -0,0 +1,18 @@ +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import GraphModule + +from ._lower_to_native_backend import _lower_to_native_backend + + +__all__ = ["lower_to_qnnpack"] + + +def lower_to_qnnpack( + model: GraphModule, + qconfig_map: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], +) -> GraphModule: + """Lower a quantized reference model (with reference quantized operator patterns) + to qnnpack + """ + return _lower_to_native_backend(model, qconfig_map, node_name_to_scope) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78849692a45efab6b7ce3af208ee16f6d77286c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py @@ -0,0 +1,228 @@ +import copy +import operator +from typing import Any, TYPE_CHECKING + +import torch +from torch.ao.quantization import ( + default_weight_fake_quant, + default_weight_observer, + FakeQuantizeBase, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.observer import _PartialWrapper +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx + + +if TYPE_CHECKING: + from collections.abc import Callable + + +# TODO: move all LSTM util functions from fx/utils.py to this file +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + example_inputs: tuple[Any, ...], + backend_config: BackendConfig | None = None, + linear_output_obs_ctr: _PartialWrapper | None = None, + sigmoid_obs_ctr: _PartialWrapper | None = None, + tanh_obs_ctr: _PartialWrapper | None = None, + cell_state_obs_ctr: _PartialWrapper | None = None, + hidden_state_obs_ctr: _PartialWrapper | None = None, + split_gates: bool = False, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + This is meant to be used to convert a float module to an observed module in the + custom module flow. + + Args: + `float_lstm`: The float LSTM module + `example_inputs`: example inputs for the forward function of the LSTM module + `backend_config`: BackendConfig to use to observe the LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + assigned to the inner ops. + """ + + def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. + """ + if isinstance(obs_ctr(), FakeQuantizeBase): + weight = default_weight_fake_quant + else: + weight = default_weight_observer + return QConfig(activation=obs_ctr, weight=weight) + + quantizable_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, + float_lstm.hidden_size, + float_lstm.num_layers, + float_lstm.bias, + float_lstm.batch_first, + float_lstm.dropout, + float_lstm.bidirectional, + split_gates=split_gates, + ) + quantizable_lstm.qconfig = float_lstm.qconfig + + for idx in range(float_lstm.num_layers): + quantizable_lstm.layers[idx] = ( + torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( + float_lstm, + idx, + float_lstm.qconfig, + batch_first=False, + split_gates=split_gates, + ) + ) + + # Build QConfigMapping for the LSTM cell + # Note: FloatFunctional qconfigs will be configured separately below + cell_qm = QConfigMapping().set_global(float_lstm.qconfig) # type: ignore[arg-type] + if sigmoid_obs_ctr is not None: + cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr)) + if tanh_obs_ctr is not None: + cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr)) + + # Insert observers into each LSTM cell + # TODO: maybe make this work for layer_bw as well + for layer in quantizable_lstm.layers: + cell = layer.layer_fw.cell # type: ignore[union-attr] + if not isinstance(cell, torch.nn.Module): + raise AssertionError("cell should be a nn.Module") + cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) + # HACK: Manually replace the activation_post_process following these ops. + # This is needed for FloatFunctional ops because there is currently no way + # to configure these ops in FX graph mode quantization today. This is because + # the FloatFunctional modules simply disappear from the graph after tracing. + # In the future, we should rewrite quantizable LSTM without FloatFunctionals. + if not split_gates: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + else: + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add (input) + (torch.add, 1): linear_output_obs_ctr, # gates.add (forget) + (torch.add, 2): linear_output_obs_ctr, # gates.add (cell) + (torch.add, 3): linear_output_obs_ctr, # gates.add (output) + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + add_count = 0 + mul_count = 0 + for node in cell.graph.nodes: + op_index: tuple[Callable, int] | None = None # e.g. (torch.add, 1) + if node.target is torch.add: + op_index = (torch.add, add_count) + add_count += 1 + elif node.target is torch.mul: + op_index = (torch.mul, mul_count) + mul_count += 1 + else: + # Neither torch.add nor torch.mul + continue + if op_index not in op_index_to_activation_post_process_ctr: + continue + if len(node.users) != 1: + raise AssertionError("expected exactly one user for the node") + activation_post_process_name = next(iter(node.users.keys())).name + activation_post_process_ctr = op_index_to_activation_post_process_ctr[ + op_index + ] + if activation_post_process_ctr is not None: + setattr( + cell, activation_post_process_name, activation_post_process_ctr() + ) + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantizable_lstm + + +def _get_reference_quantized_lstm_module( + observed_lstm: torch.ao.nn.quantizable.LSTM, + backend_config: BackendConfig | None = None, +) -> torch.ao.nn.quantized.LSTM: + """ + Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM` + with observers or fake quantizes inserted through `prepare_fx`, e.g. from + `_get_lstm_with_individually_observed_parts`. + + This is meant to be used to convert an observed module to a quantized module in the + custom module flow. + + Args: + `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx` + `backend_config`: BackendConfig to use to produce the reference quantized model + + Return: + A reference `torch.ao.nn.quantized.LSTM` module. + """ + quantized_lstm = torch.ao.nn.quantized.LSTM( + observed_lstm.input_size, + observed_lstm.hidden_size, + observed_lstm.num_layers, + observed_lstm.bias, + observed_lstm.batch_first, + observed_lstm.dropout, + observed_lstm.bidirectional, + ) + + for i, layer in enumerate(quantized_lstm.layers): + cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] + cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] + if not isinstance(cell, torch.fx.GraphModule): + raise AssertionError("cell must be converted to a torch.fx.GraphModule") + # HACK: Manually remove input quantize nodes and output dequantize nodes, + # since custom modules expect quint8 inputs and outputs for now. Note that + # this functionality is supposedly handled through PrepareCustomConfig's + # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that + # API doesn't currently handle tuple inputs and outputs, so we have to do + # this manually for now. In the future we should (1) relax the restriction + # on custom module input/output dtypes, and (2) expand support for complex + # input/output structures. + for node in cell.graph.nodes: + if node.target is torch.quantize_per_tensor: + arg = node.args[0] + # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1]) + if arg.target == "x" or ( + arg.target is operator.getitem and arg.args[0].target == "hidden" + ): + with cell.graph.inserting_before(node): + node.replace_all_uses_with(arg) + cell.graph.erase_node(node) + if node.target == "output": + # Remove all dequantize nodes in the output tuple + for arg in node.args[0]: + with cell.graph.inserting_before(node): + node.replace_input_with(arg, arg.args[0]) + cell.graph.eliminate_dead_code() + cell.recompile() + layer.layer_fw.cell = cell # type: ignore[union-attr] + return quantized_lstm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79194caa4a17b9f2db99981b184081d09df80e84 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py @@ -0,0 +1,231 @@ +# mypy: allow-untyped-defs +import sys +from collections.abc import Callable, Iterable +from typing import Any + +import torch +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import MatchAllNode, Pattern +from torch.fx.graph import Graph, Node +from torch.nn.utils.parametrize import type_before_parametrizations + +from .graph_module import _is_observed_standalone_module +from .quantize_handler import QuantizeHandler + + +__all__: list[str] = [] + +# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type +# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]` +_MatchResult = tuple[Node, list[Node], Pattern | None, QuantizeHandler] + +_MatchResultWithQConfig = tuple[ + Node, list[Node], Pattern | None, QuantizeHandler, QConfigAny +] + + +# Note: The order of patterns is important! match function will take whatever is matched first, so we'll +# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu. +# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns, +# we'll start from the last node of the graph and traverse back. +def _is_match(modules, node, pattern, max_uses=sys.maxsize): + """Matches a node in fx against a pattern""" + if isinstance(pattern, tuple): + self_match, *arg_matches = pattern + if self_match is getattr: + if len(pattern) != 2: + raise AssertionError("Expecting getattr pattern to have two elements") + arg_matches = [] + else: + self_match = pattern + arg_matches = [] + + if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): + return True + + if node == pattern: + return True + + if not isinstance(node, Node) or len(node.users) > max_uses: + return False + + if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): + if node.op != "call_module": + return False + if type_before_parametrizations(modules[node.target]) != self_match: + return False + elif callable(self_match): + if node.op != "call_function" or node.target is not self_match: + return False + elif node.target is getattr: + if node.args[1] != pattern[1]: + return False + elif isinstance(self_match, str): + if node.op != "call_method" or node.target != self_match: + return False + elif node.target != self_match: + return False + + if not arg_matches: + return True + + if len(arg_matches) != len(node.args): + return False + + return all( + _is_match(modules, node, arg_match, max_uses=1) + for node, arg_match in zip(node.args, arg_matches) + ) + + +def _find_matches( + graph: Graph, + modules: dict[str, torch.nn.Module], + patterns: dict[Pattern, QuantizeHandler], + root_node_getter_mapping: dict[Pattern, Callable], + standalone_module_names: list[str] | None = None, + standalone_module_classes: list[type] | None = None, + custom_module_classes: list[Any] | None = None, +) -> dict[str, _MatchResult]: + """ + Matches the nodes in the input graph to quantization patterns, and + outputs the information needed to quantize them in future steps. + + Inputs: + - graph: an fx.Graph object + - modules: a mapping of fully qualified module name to instance, + for example, {'foo': ModuleFoo, ...} + - patterns: a mapping from a tuple of nodes in reverse order to + uninitialized QuantizeHandler subclass. + + Outputs a map of + node_name -> + (node, matched_values, matched_pattern, QuantizeHandler instance, + qconfig) + + For example, { + 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, + , QConfig(...)), + ... + } + """ + if custom_module_classes is None: + custom_module_classes = [] + + if standalone_module_classes is None: + standalone_module_classes = [] + + if standalone_module_names is None: + standalone_module_names = [] + + match_map: dict[str, _MatchResult] = {} + all_matched: set[str] = set() + + def _recursive_record_node_in_match_map( + last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value + ): + if isinstance(node_pattern, Node): + match_map[node_pattern.name] = ( + last_node, + matched_node_pattern, + pattern, + match_value, + ) + elif not isinstance(node_pattern, Iterable): + return + else: + for n in node_pattern: + _recursive_record_node_in_match_map( + last_node, match_map, n, matched_node_pattern, pattern, match_value + ) + + # TODO: 1. merge with fuse matcher 2. document the code + def record_match(pattern, node, last_node, matched_node_pattern, match_map): + if isinstance(pattern, tuple): + s, *args = pattern + is_single_arg = len(args) == 1 + current_node_pattern: list[Node] = [] + record_match(s, node, last_node, matched_node_pattern, match_map) + if pattern[0] is not getattr: + for subpattern, arg in zip(args, node.args): + record_match(subpattern, arg, node, current_node_pattern, match_map) + if len(current_node_pattern) > 1: + # current_node_pattern is the node pattern we get from matching + # the subpattern with arguments of the node + # we use is_single_arg to recover the original structure of the pattern + # if the original pattern has a single argument, we will have + # (original_op, (original_arg, ...)) + # otherwise, we'll have a list of arguments + # (original_op, arg0, arg1, arg2, ...) + if is_single_arg: + matched_node_pattern.append(tuple(current_node_pattern)) + else: + matched_node_pattern.extend(list(current_node_pattern)) + else: + matched_node_pattern.append(current_node_pattern[0]) + else: + matched_node_pattern.append(node) + + for node in reversed(graph.nodes): + if node.name not in match_map and node.name not in all_matched: + for pattern, quantize_handler_cls in patterns.items(): + root_node_getter = root_node_getter_mapping.get(pattern) + if _is_match(modules, node, pattern) and node.name not in match_map: + matched_node_pattern: list[Node] = [] + record_match(pattern, node, node, matched_node_pattern, match_map) + quantize_handler = quantize_handler_cls( # type: ignore[operator] + matched_node_pattern, modules, root_node_getter + ) + last_node = node + # record the match for all nodes in the pattern + _recursive_record_node_in_match_map( + last_node, + match_map, + # we need to record all nodes in the matched pattern in the match_map + matched_node_pattern, + # this is a part of the value corresponding to the node + matched_node_pattern, + pattern, + quantize_handler, + ) + break + + # add custom module instances to the match result + if modules is None: + raise AssertionError("modules must not be None") + for node in graph.nodes: + if ( + node.op == "call_module" + and type(modules[node.target]) in custom_module_classes + ): + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_custom_module=True), + ) + + def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]): + if modules is None: + raise AssertionError("modules must not be None") + return ( + node_target in standalone_module_names + or type(modules[node_target]) # type: ignore[operator] + in standalone_module_classes # type: ignore[operator] + ) + + # add standalone modules to the match + for node in graph.nodes: + if node.op == "call_module" and ( + is_standalone_module(node.target, modules) + or _is_observed_standalone_module(modules[node.target]) + ): + # add node to matched nodes + match_map[node.name] = ( + node, + node, + None, + QuantizeHandler(node, modules, is_standalone_module=True), + ) + + return match_map diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e86f95d67aba092daff6a3a14a14767f29d249a2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import copy +from collections import OrderedDict +from typing import Any + +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.utils import Pattern + + +__all__ = [ + "get_default_fusion_patterns", + "get_default_quant_patterns", + "get_default_output_activation_post_process_map", +] + +# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency) +QuantizeHandler = Any + +# pattern for conv bn fusion +_DEFAULT_FUSION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + + +def _register_fusion_pattern(pattern): + def insert(fn): + _DEFAULT_FUSION_PATTERNS[pattern] = fn + return fn + + return insert + + +def get_default_fusion_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_FUSION_PATTERNS) + + +_DEFAULT_QUANTIZATION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict() + +# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant +_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: dict[Pattern, QuantizeHandler] = {} +_DEFAULT_OUTPUT_OBSERVER_MAP: dict[Pattern, QuantizeHandler] = {} + + +# Register pattern for both static quantization and qat +def _register_quant_pattern(pattern, fixed_qparams_observer=None): + def insert(fn): + _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if fixed_qparams_observer is not None: + _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = ( + FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + ) + _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer + return fn + + return insert + + +# Get patterns for both static quantization and qat +def get_default_quant_patterns() -> dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS) + + +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map( + is_training, +) -> dict[Pattern, ObserverBase]: + if is_training: + return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP) + else: + return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP) + + +# Example use of register pattern function: +# @_register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +# class ConvOrLinearBNReLUFusion(): +# def __init__(...): +# ... +# + + +def _sorted_patterns_dict( + patterns_dict: dict[Pattern, QuantizeHandler], +) -> dict[Pattern, QuantizeHandler]: + """ + Return a sorted version of the patterns dictionary such that longer patterns are matched first, + e.g. match (F.relu, F.linear) before F.relu. + This works for current use cases, but we may need to have a more clever way to sort + things to address more complex patterns + """ + + def get_len(pattern): + """this will calculate the length of the pattern by counting all the entries + in the pattern. + this will make sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before + (nn.BatchNorm, nn.Conv2d) so that we can match the former first + """ + len = 0 + if isinstance(pattern, tuple): + for item in pattern: + len += get_len(item) + else: + len += 1 + return len + + return OrderedDict( + sorted( + patterns_dict.items(), + key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1, + ) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2fab3f27eb917b22368dae04cd908f2f81a7c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py @@ -0,0 +1,2251 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from dataclasses import asdict +from typing import Any + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + _DerivedObserverOrFakeQuantize, + FixedQParamsFakeQuantize, + FixedQParamsObserver, + ObserverBase, + ObserverOrFakeQuantize, + PlaceholderObserver, +) +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, + get_module_to_qat_module, + get_pattern_to_dtype_configs, +) +from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper +from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize import convert, propagate_qconfig_ +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.ao.quantization.utils import ( + _parent_name, + get_qconfig_dtypes, + get_swapped_custom_module_class, + NodePattern, + Pattern, +) +from torch.fx import GraphModule +from torch.fx.graph import Graph, Node +from torch.fx.node import Argument + +from ._equalize import is_equalization_observer, node_supports_equalization +from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry +from .match_utils import _find_matches, _MatchResultWithQConfig +from .pattern_utils import _sorted_patterns_dict +from .qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, + _get_flattened_qconfig_dict, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from .quantize_handler import ( + _default_root_node_getter, + _get_pattern_to_quantize_handlers, + QuantizeHandler, +) +from .utils import ( + _insert_dequant_stubs_for_custom_module_lstm_output, + _is_custom_module_lstm, + _maybe_get_custom_module_lstm_from_node_arg, + _qconfig_satisfies_dtype_config_constraints, + all_node_args_have_no_tensors, + assert_and_get_unique_device, + get_custom_module_class_keys, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + node_arg_is_bias, + node_arg_is_weight, + NON_QUANTIZABLE_WEIGHT_OPS, + ObservedGraphModuleAttrs, +) + + +__all__ = [ + "insert_observers_for_model", + "prepare", + "propagate_dtypes_for_known_nodes", +] + + +# list of dtypes to not add observers to +_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None] +_OBS_DTYPE_LIST = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float) + +# note: the following default target dtype info dicts are temporary, +# should be moved to the new programmable API class soon +_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, +} + +_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, +} + + +def _get_observer_kwargs( + quant_spec: QuantizationSpec | FixedQParamsQuantizationSpec, +): + kwargs_dict = asdict(quant_spec) + return copy.deepcopy(kwargs_dict) + + +def _get_qspec_for_arg( + arg: Node, + input_qspec_map: dict[Node, QuantizationSpecBase], + named_modules: dict[str, torch.nn.Module], +) -> QuantizationSpecBase | None: + while _is_activation_post_process_node(arg, named_modules): + arg = arg.args[0] # type: ignore[assignment] + return input_qspec_map.get(arg) + + +def _create_obs_or_fq_from_qspec( + quantization_spec: QuantizationSpecBase | None, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + """Create observer or fake quantize objects based on quantization spec + + Args: + quantization_spec: used to store parameters to create the observer or fake quantizer + obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant + instance, it may be reused for different edge/output depending on configuration + """ + if quantization_spec is None: + return None + if isinstance(quantization_spec, SharedQuantizationSpec): + edge_or_node = quantization_spec.edge_or_node + if edge_or_node not in obs_or_fq_map: + raise AssertionError( + "please make sure only refer to edge or node that has " + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + ) + return obs_or_fq_map[edge_or_node] + elif isinstance(quantization_spec, DerivedQuantizationSpec): + # can't use asdict, so not calling get_observer_kwargs here + kwargs = { + "dtype": quantization_spec.dtype, + "derive_qparams_fn": quantization_spec.derive_qparams_fn, + "quant_min": quantization_spec.quant_min, + "quant_max": quantization_spec.quant_max, + "qscheme": quantization_spec.qscheme, + "ch_axis": quantization_spec.ch_axis, + } + edge_or_nodes = quantization_spec.derived_from + obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + # pyrefly: ignore [unsupported-operation] + kwargs["obs_or_fqs"] = obs_or_fqs + return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() + elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): + kwargs = _get_observer_kwargs(quantization_spec) + observer_ctr = FixedQParamsObserver.with_args(**kwargs) + if is_qat: + return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)() + else: + return observer_ctr() + + if not isinstance(quantization_spec, QuantizationSpec): + raise AssertionError("quantization_spec must be a QuantizationSpec") + observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr + kwargs = _get_observer_kwargs(quantization_spec) + kwargs.pop("observer_or_fake_quant_ctr") + # we will remove is_dynamic from QuantizationSpec because + # it seems that dynamic range quantization + obs_or_fq_class = observer_or_fake_quant_ctr + if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): + obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] + if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] + kwargs.pop("ch_axis") + return observer_or_fake_quant_ctr.with_args(**kwargs)() + + +def _needs_obs_or_fq( + prev_output_dtype: Any, + prev_output_is_dynamic: bool, + cur_target_dtype: Any, + cur_target_is_dynamic: bool, + reuse_input_obs_or_fq: bool, + is_zeroth_arg: bool = False, +) -> bool: + """ + note: we will treat "not specified" as torch.float for now + utility function that checks if we should insert an observer or fake quant node + base on the requested dtype for the nodes from user + + is_zeroth_arg: we only dynamically quantize the first arg of the node right now + this should be removed when we enable configuring dynamic quantization + for a specific argument, this can be removed if we deprecate fx graph mode + quantization + + """ + + # need to insert placeholder observer for dynamic quantization so that it can + # be converted to choose_qparams -> q -> dq in convert step + if cur_target_is_dynamic: + if cur_target_dtype not in _OBS_DTYPE_LIST: + raise AssertionError( + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + ) + if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST: + raise AssertionError( + "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST" + ) + return is_zeroth_arg + if reuse_input_obs_or_fq: + return False + # non dynamic quantization + if cur_target_dtype in _OBS_DTYPE_LIST: + return ( + prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] + and cur_target_dtype != prev_output_dtype + ) + + # lots of error checking are skipped here for now + return False + + +def _is_activation_post_process_node( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == "call_module" + and _is_activation_post_process(named_modules[str(node.target)]) + ) + + +def _get_dtype_and_is_dynamic( + obs_or_fq: ObserverOrFakeQuantize | None, +) -> tuple[torch.dtype | None, bool]: + """Given a constructor for observer or fake quant module, returns + a Tuple of dtype and is_dynamic + """ + # TODO: instead of instantiating the instance, we can use inspect to get the default args + if obs_or_fq is None: + return None, False + else: + return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value] + + +def _is_input_arg_dtype_supported_by_backend( + arg: Argument, + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, + backend_config: BackendConfig, +) -> bool: + """Check if the configured qconfig for the argument + is supported by the backend or not + """ + if isinstance(arg, (list, tuple)): + return all( + _is_input_arg_dtype_supported_by_backend( + a, node, qconfig, dtype_config, backend_config + ) + for a in arg + ) + if not isinstance(arg, Node): + return True + # TODO: support check for standalone module + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + if is_activation: + input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr" + ) + input_act_obs_or_fq = ( + input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None + ) + qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic( + input_act_obs_or_fq + ) + # TODO(future PR): remove the cast to bool below after figuring + # out why backend_config has is_dynamic set to None in some cases. + return (dtype_config.input_dtype is None) or ( + dtype_config.input_dtype == qconfig_dtype + and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) + and _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.input_dtype_with_constraints + ) + ) + elif is_weight: + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", None + ) + weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None + qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq) + backend_config_weight_dtype = dtype_config.weight_dtype + dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False + ) + return backend_config_weight_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + else: # bias + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", None + ) + bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None + qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq) + backend_config_bias_dtype = dtype_config.bias_dtype + return ( + backend_config_bias_dtype is None + or qconfig_bias_dtype == backend_config_bias_dtype + ) + + +def _is_output_dtype_supported_by_backend( + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, +) -> bool: + """Check if the configured qconfig for the output + is supported by the backend or not + """ + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + backend_config_output_dtype = dtype_config.output_dtype + # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend + # from input activation check can be reused here + qconfig_output_dtype = None + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic( + output_act_obs_or_fq + ) + # TODO: this is a hack because we can only specify one activation_obs_or_fq for + # qconfig (qconfig.activation), and we are only supporting dynamically quantized + # linear op which has fp32 output dtype, this should be removed if we generalize + # the structure of qconfig in the future + if qconfig_output_is_dynamic: + qconfig_output_dtype = torch.float32 + dtype_matches = qconfig_output_dtype == backend_config_output_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.output_dtype_with_constraints + ) + return backend_config_output_dtype is None or ( + dtype_matches and qconfig_satisfies_constraints + ) + + +def _is_observer_in_same_graph( + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat, +): + """Check if observer in same graph + when the node output is not fp32 and input is 'placeholder' + the input is assumed to be quantized, so it is observed + in a different place rather than not observed. + """ + node_output_dtype = _get_arg_target_dtype_as_output( + node, named_modules, obs_or_fq_map, is_qat + ) + if len(node.args) > 0 and isinstance(node.args[0], Node): + if ( + node_output_dtype in [torch.quint8, torch.uint8] + and node.args[0].op == "placeholder" + ): + return False + return True + + +def _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern: Pattern | None, + matched_node_pattern: list[Node] | None, + qconfig: QConfigAny, + backend_config: BackendConfig, +) -> bool: + """Check if the dtype configuration of a pattern is supported by + the backend or not, and whether the qconfig satisfies constraints + specified in the corresponding dtype config. + """ + if backend_config is None or pattern is None: + return True + if matched_node_pattern is None or len(matched_node_pattern) < 1: + raise AssertionError("matched_node_pattern must be non-empty") + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + input_node = root_node + output_node = matched_node_pattern[0] + for dtype_config in dtype_configs: + # check if arg dtype are supported + supported = True + for arg in list(input_node.args) + list(input_node.kwargs.values()): + supported = supported and _is_input_arg_dtype_supported_by_backend( + arg, input_node, qconfig, dtype_config, backend_config + ) + # check if output dtype is supported + supported = supported and _is_output_dtype_supported_by_backend( + output_node, qconfig, dtype_config + ) + if supported: + return True + return False + + +def _get_standalone_module_configs( + node: Node, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, + parent_qconfig: QConfigAny, + parent_backend_config: BackendConfig | None, +) -> tuple[QConfigMapping, tuple[Any, ...], PrepareCustomConfig, BackendConfig | None]: + """ + Returns the standalone module QConfigMapping and PrepareCustomConfig + for `node`, assuming that the module pointed to by `node` is + a standalone modules. + """ + module_name = str(node.target) + module_type = type(named_modules[module_name]) # type: ignore[index] + # name config has precedence over type config + config_entry = StandaloneModuleConfigEntry(None, (), None, None) + config_entry = prepare_custom_config.standalone_module_classes.get( + module_type, config_entry + ) + config_entry = prepare_custom_config.standalone_module_names.get( + module_name, config_entry + ) + # fallback to use parent module's qconfig if user didn't specify qconfig dict + qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global( + parent_qconfig + ) + example_inputs = config_entry.example_inputs + prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig() + backend_config = config_entry.backend_config or parent_backend_config + return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + + +def _qat_swap_modules( + root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]] +) -> None: + convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False) + + +def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]): + if isinstance(matched_node_pattern, Node): + s.add(matched_node_pattern.name) + elif isinstance(matched_node_pattern, (list, tuple)): + for maybe_node in matched_node_pattern: + _add_matched_node_name_to_set(maybe_node, s) + + +def _insert_obs_or_fq( + node: Node, + obs_or_fq: ObserverOrFakeQuantize, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + model_device: torch.device | None = None, +) -> Node: + """ + Attaches `obs_or_fq` to `model`, and creates a node which calls + `obs_or_fq` on the output of `node`. + + obs_or_fq: an instance of Observer or FakeQuantize module + """ + if model_device is None: + model_device = assert_and_get_unique_device(model) + if model_device: + obs_or_fq.to(model_device) + # add obs_or_fq module as attribute + if is_equalization_observer(obs_or_fq): + prefix = node.name + "_equalization_process_" + else: + prefix = "activation_post_process_" + get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix) + obs_or_fq_name = get_new_obs_or_fq_name(model) + setattr(model, obs_or_fq_name, obs_or_fq) + named_modules[obs_or_fq_name] = obs_or_fq + with graph.inserting_after(node): + new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {}) + return new_obs + + +def _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern: NodePattern, + last_node: Node, + qconfig: QConfigAny, + qhandler: QuantizeHandler | None, + backend_config: BackendConfig, + named_modules: dict[str, torch.nn.Module], + cache_for_no_tensor_check: dict[Node, bool], + processed_nodes: set[Node], +) -> None: + """Sets the target_dtype_info for each node in matched_node_pattern + Note: processed_nodes is used to ensure we only process each node once + """ + if isinstance(matched_node_pattern, (list, tuple)): + for node_pattern in matched_node_pattern: + _set_target_dtype_info_for_matched_node_pattern( + node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # set target_dtype_info if matched_node_pattern is a Node + # other types of matched object, e.g. int, float literals, are ignored + elif isinstance(matched_node_pattern, Node): + # for pyre + if not isinstance(matched_node_pattern, Node): + raise AssertionError("matched_node_pattern must be a Node") + node = matched_node_pattern + if node in processed_nodes: + return + processed_nodes.add(node) + + if qconfig is None: + return + # TODO: refactor the following code in terms of apply a qconfig to a pattern + # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1) + # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act, + # and set output_obs_or_fq_ctr based on qconfig.output_act + # this also requires we extend the structure of QConfig to support more fine + # grained configurations + target_dtype_info: dict[str, Any] = _get_target_activation_dtype_for_node( + node, + qconfig, + qhandler, + named_modules, + backend_config, + cache_for_no_tensor_check, + ) + node.meta["target_dtype_info"] = target_dtype_info + + +def _get_target_activation_dtype_for_node( + node: Node, + qconfig: QConfigAny, + qhandler: QuantizeHandler | None, + named_modules: dict[str, torch.nn.Module], + backend_config: BackendConfig, + cache_for_no_tensor_check: dict[Node, bool], +) -> dict[str, Any]: + """ + For each op attribute in the op's input activation, output activation, + weight, bias - returns the settings of dtype and is_dynamic we expect + for the `quantize` call in the reference model representation, or None + if there is no `quantize` call needed. + + For example, if we have a node corresponding to `op0` in + + x0 -> op0 -> x1 + + And we want a reference quantized representation to be + + x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1 + + Then this function will return + + { + "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + } + + TODO(future PR, if needed): explicitly spell out the non-Tensor + dtypes. + """ + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + return { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + # get qconfig to determine the eventual dtype of this node + if qconfig is not None: + act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig) + + # Currently `QConfig` only has one `activation` field. + # For static quantization, it is reused for both input + # and output activation. For dynamic quantization, this + # field is currently only used for the input activation, + # with the output activation being in fp32. + # In the future this may change as we add more fields + # to the `QConfig` object. + bias_dtype = ( + torch.float16 + if ( + act_dtype == torch.float16 + and weight_dtype == torch.float16 + and (not input_act_is_dynamic) + ) + else torch.float + ) + + is_general_tensor_value_op = ( + qhandler is not None and qhandler.is_general_tensor_value_op() + ) + + _is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + + weight_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + weight_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("weight") + + bias_index = None + if ( + isinstance(node, Node) + and node.op == "call_function" + and node.target in backend_config._pattern_complex_format_to_config + ): + bias_index = backend_config._pattern_complex_format_to_config[ + node.target + ]._input_type_to_index.get("bias") + + return { + "input_act_obs_or_fq_ctr": qconfig.activation, + "weight_obs_or_fq_ctr": qconfig.weight, + "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype), + "weight_index": weight_index, + "bias_index": bias_index, + "output_act_obs_or_fq_ctr": qconfig.activation, + "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig), + "input_output_share_observers": is_general_tensor_value_op, + "_is_standalone_module": _is_standalone_module, + } + return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) + + +def _get_output_act_obs_or_fq( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> ObserverOrFakeQuantize | None: + """Get the constructor for observer or fake quant object for + the argument in the original graph as the output of previous node, + skipping inserted observers + + We are assuming that the observers are inserted correctly, and the dtype for + argument in quantized graph will match what is specified by the qconfig + """ + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + if "quantization_annotation" in arg.meta: + return _create_obs_or_fq_from_qspec( + arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + + # Custom module LSTM output is a tuple that we broke down into the internal nodes in order + # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + # Since we modified the graph in this case, we must trace back from the args through + # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would + # not be able to accurately detect whether this node is a consumer of custom module LSTM. + custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg( + arg, named_modules + ) + output_act_obs_or_fq_ctr = None + if custom_module_lstm_node is not None: + output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + elif _is_activation_post_process_node(arg, named_modules): + observed_arg = arg.args[0] + if not isinstance(observed_arg, Node): + raise AssertionError("Currently we only support observing Node") + if "quantization_annotation" in observed_arg.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + observed_arg.meta["quantization_annotation"].output_qspec, + obs_or_fq_map, + is_qat, + ) + else: + if "target_dtype_info" not in observed_arg.meta: + raise AssertionError( + "expected 'target_dtype_info' in observed_arg.meta" + ) + output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + else: + if "target_dtype_info" in arg.meta: + output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + + return output_act_obs_or_fq + + +def _get_arg_target_dtype_as_output( + arg: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> torch.dtype | None: + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic( + arg_as_output_act_obs_or_fq + ) + return arg_as_output_target_dtype + + +def _get_arg_as_input_act_obs_or_fq( + arg: Node, + node: Node, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> ObserverOrFakeQuantize | None: + """Get the observer or fake quant constructor for the Argument `arg`, as input + to Node `node` + """ + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + # "input_qspec_map" is the more general design we'll use for pt2e path + # it is a map from input argument node to observer or fake quant constructor, for example + # for the following graph: + # x -> conv -> output + # + # we may annotate conv node like the following: + # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...) + # + if "quantization_annotation" in node.meta: + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules) + if input_arg_qspec is None: + input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR() + else: + input_arg_obs_or_fq = _create_obs_or_fq_from_qspec( + input_arg_qspec, obs_or_fq_map, is_qat + ) + return input_arg_obs_or_fq + + # we can remove the following path in the future if fx graph mode quantization is + # no longer used + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + obs_or_fq_ctr = None + if is_activation: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + elif is_weight: + if node.target not in NON_QUANTIZABLE_WEIGHT_OPS: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + else: + obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR + ) + return obs_or_fq_ctr() if obs_or_fq_ctr else None + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Node | Any, + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: QuantizeHandler | None, + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: BackendConfig | None = None, + model_device: torch.device | None = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") + # default (no observer) + new_arg = arg + + is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + # TODO: move this to a separate function + if not is_standalone_module: + # Note: qconfig can be None in this branch this we are getting act/fq from + # node.meta now + # regular flow for most nodes, except standalone modules + + if "quantization_annotation" in node.meta: + reuse_input_obs_or_fq = node.meta[ + "quantization_annotation" + ]._reuse_input_obs_or_fq + else: + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") + # TODO: we are assuming "target_dtype_info" exists here, maybe + # a default value also need to be provided here + target_dtype_info = node.meta["target_dtype_info"] + # for nodes that doesn't have `reuse_input_obs_or_fq` configured, + # we'll default to False, this makes configuring this field optional for users + reuse_input_obs_or_fq = target_dtype_info.get( + "reuse_input_obs_or_fq", False + ) + arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq( + arg, node, named_modules, obs_or_fq_map, is_qat + ) + ( + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) + + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq( + arg, named_modules, obs_or_fq_map, is_qat + ) + ( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + + needs_obs_or_fq = _needs_obs_or_fq( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + reuse_input_obs_or_fq, + is_zeroth_arg=len(node.args) > 0 and arg is node.args[0], + ) + + else: + if qconfig is None: + raise AssertionError("qconfig must not be None") + # custom flow for standalone modules + _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( + node, named_modules, prepare_custom_config, qconfig, backend_config + ) + sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes + + # for args, this is set to the index of the current arg + # for kwargs, this is left at None + cur_input_idx = None + for arg_idx, arg_to_check in enumerate(node.args): + if arg_to_check is arg: + cur_input_idx = arg_idx + break + + if cur_input_idx is None: + needs_obs_or_fq = False + else: + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + arg, named_modules, obs_or_fq_map, is_qat + ) + arg_as_input_target_dtype = ( + torch.quint8 + if cur_input_idx in sm_input_quantized_idxs + else torch.float + ) + needs_obs_or_fq = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + ) and (arg_as_input_target_dtype != torch.float) + + act_post_process_ctr = qconfig.activation + arg_as_input_act_obs_or_fq = ( + act_post_process_ctr() if act_post_process_ctr else None + ) + + if needs_obs_or_fq: + existing_obs_node = None + + # Before using the new observer, check if an observer + # of the correct type already exists. If it does, use it. + # This prevents duplicate observer insertions if a node is + # used by multiple nodes. + # TODO: this is looking into how the value is used in the future + # we should remove this + # removing this means we insert one observer for each use, even if they + # have the same dtype, we can have an extra pass that removes the extra observers + for maybe_obs_node in arg.users: + if maybe_obs_node.op == "call_module": + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if ( + type(maybe_obs_mod) is type(arg_as_input_act_obs_or_fq) + and maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined] + ): + arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] + existing_obs_node = maybe_obs_node + break + + if arg_as_input_act_obs_or_fq is None: + raise AssertionError("arg_as_input_act_obs_or_fq must not be None") + obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq + if existing_obs_node is None: + new_obs_node = _insert_obs_or_fq( + arg, + arg_as_input_act_obs_or_fq, + model, + named_modules, + graph, + model_device, + ) + # override this arg to be the observed arg + new_arg = new_obs_node + else: + new_arg = existing_obs_node + + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + qhandler: QuantizeHandler | None, + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: BackendConfig | None = None, + model_device: torch.device | None = None, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + Note: backend_config only needed for standalone_module node + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_args.append(new_arg) + + new_kwargs = {} + for k, kwarg in node.kwargs.items(): + new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + kwarg, + qconfig, + model, + named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + new_kwargs[k] = new_kwarg + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + node.kwargs = new_kwargs + + +def _maybe_insert_input_equalization_observers_for_node( + node: Node, + equalization_qconfig: Any, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + is_branch: bool, +) -> None: + """ + If `node` needs to be equalized, find the input/weight observers it needs in + `equalization_qconfig`, creates them, and inserts it into `graph`. + + If `node` does not need an equalization observer, returns None. + """ + if equalization_qconfig is None or not node_supports_equalization( + node, named_modules + ): + return + + if is_branch: + warnings.warn( + f"Cannot equalize {node} because it is part of a branch.", stacklevel=2 + ) + return + + new_args = [] + for arg in node.args: + if not isinstance(arg, Node) or node_arg_is_bias(node, arg): + new_args.append(arg) + continue + + is_weight = node_arg_is_weight(node, arg) + + act_eq_process_ctr = ( + equalization_qconfig.weight + if is_weight + else equalization_qconfig.input_activation + ) + + new_eq_obs_mod = act_eq_process_ctr() + new_eq_obs_node = _insert_obs_or_fq( + arg, new_eq_obs_mod, model, named_modules, graph + ) + + new_args.append(new_eq_obs_node) + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Node | None: + """ + If `node` needs an output observer, creates it, inserts it into `graph` + and returns it. + + If `node` does not need an output observer, returns None. + + Note: inserting dynamic quantization ops for output is not supported in fx graph mode + quantization code path right now + """ + if node.op == "output": + raise AssertionError("observer insertion for outputs is handled elsewhere") + + is_standalone_module = False + if "quantization_annotation" in node.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + else: + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") + is_standalone_module = node.meta["target_dtype_info"].get( + "_is_standalone_module", False + ) + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get( + "output_act_obs_or_fq_ctr" + ) + output_act_obs_or_fq = ( + output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + ) + target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) + # uncomment after we support reuse_input_obs_or_fq properly by having separate + # implementations for this key instead of reusing the input_output_share_observers + # code + # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) + # for now we set this to False since reuse_input_obs_or_fq for + # the output of a node is implementation in the same code path as observer sharing, + # we should refactor this part to make it clearer in the future + # and we would be able to read this from config directly + reuse_input_obs_or_fq = False + + # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False + # because the prev_output is the output of an fp32 op, although technically + # we should get the dtype of the output from node.meta["val"] in the future + # if we deprecate fx graph mode quantization + needs_obs_or_fq = _needs_obs_or_fq( + torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq + ) + # currently the activation in QConfig(activation=...,) is for both input + # and output, and when the activation is configured to be dynamic quantization + # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means + # the input should by dynamically quantized, but output should not be quantized + # + # there is no way we can specify different observer/fq for input and output + # activation through QConfig today, this limitation is lifted in the + # quantizer/annotation API in pytorch 2.0 export quantization code path, + # but since this code is reused, annotating output to be dynamically quantized + # would not work either for that. + # we can change QConfig to support input/output activation if we want + # to remove the following check, or if we can deprecate fx graph mode quantization + if target_is_dynamic: + needs_obs_or_fq = False + + # we never insert observers to output of standalone module, we assume + # if needed, they are inserted inside the standalone module + needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module) + + if needs_obs_or_fq: + obs_or_fq_map[node] = output_act_obs_or_fq + return _insert_obs_or_fq( + node, output_act_obs_or_fq, model, named_modules, graph + ) + else: + return None + + +def _maybe_insert_observers_before_graph_output( + graph_output_node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If the output needs to be quantized and there are any nodes + in the output which are not already observed, inserts observers + for those nodes. + """ + + def _recursive_maybe_replace_node_with_obs( + maybe_node: Argument, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + ) -> Argument: + """ + Navigate an arbitrary data structure of lists, tuples, dicts. + For each container type, recurse on all inputs. Once any Node + is found, insert an observer if needed and do not recurse further. + + For example, given a structure of + + {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} + + we recurse down to bar1 and bar3, observe them if necessary, + and if we inserted an observer then replace the original node + with its observer. + + Returns the data structure with all nodes needing observation being + replaced by their observers. + """ + if isinstance(maybe_node, Node): + # check dtype of this node + arg_as_output_target_dtype = _get_arg_target_dtype_as_output( + maybe_node, named_modules, obs_or_fq_map, is_qat + ) + observer_mod = None + arg_as_input_target_dtype = torch.float + if "target_dtype_info" in maybe_node.meta: + observer_cls = maybe_node.meta["target_dtype_info"].get( + "input_act_obs_or_fq_ctr", None + ) + if observer_cls is not None: + observer_mod = observer_cls() + arg_as_input_target_dtype = observer_mod.dtype + # TODO: this does not handle dynamic quantization yet + need_obs = ( + arg_as_output_target_dtype != arg_as_input_target_dtype + and arg_as_input_target_dtype != torch.float + ) + if need_obs: + if observer_mod is None: + raise AssertionError( + "observer_mod must not be None when need_obs is True" + ) + # insert observer + observer_node = _insert_obs_or_fq( + maybe_node, observer_mod, model, named_modules, graph + ) + return observer_node + else: + return maybe_node + elif isinstance(maybe_node, (list, tuple)): + results = [ + _recursive_maybe_replace_node_with_obs( + inner_node, model, named_modules, graph + ) + for inner_node in maybe_node + ] + if isinstance(maybe_node, list): + return results + else: + return tuple(results) + elif isinstance(maybe_node, dict): + results_dict = {} + for k, inner_v in maybe_node.items(): + results_dict[k] = _recursive_maybe_replace_node_with_obs( + inner_v, model, named_modules, graph + ) + return results_dict + elif maybe_node is None: + return None + else: + raise Exception( # noqa: TRY002 + "Unhandled type for returned node:", maybe_node + ) + + new_args = [ + _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph) + for old_arg in graph_output_node.args + ] + + graph_output_node.args = tuple(new_args) # type: ignore[assignment] + + +def _maybe_propagate_dtype_for_node( + node: Node, + target_dtype: torch.dtype | type, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node` + is a general tensor shape op, also call this function recursively on + the first argument, to propagate the dtype to the caller. + """ + node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None + node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None + # if this is a copy node, propagate to first arg + ( + _root_node, + _, + _pattern, + qhandler, + _qconfig, + ) = node_name_to_match_result_with_qconfig.get( + node.name, (None, None, None, None, None) + ) + # TODO: probably need to remove `is_general_tensor_value_op` + if qhandler is not None and qhandler.is_general_tensor_value_op(): + prev_node = node.args[0] + if isinstance(prev_node, Node): + _maybe_propagate_dtype_for_node( + prev_node, target_dtype, node_name_to_match_result_with_qconfig + ) + + +def propagate_dtypes_for_known_nodes( + graph: Graph, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], +) -> None: + """ + Currently we assume that inputs to the graph are either `torch.float` or + `torch.quint8`, which is not always correct. For ops such as + `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a + `BoolTensor`. Propagate this information throughout the graph. + + Note: not all dtypes in the graph will be correct after this pass, but a + higher percentage of them will be correct. Hopefully in the future we can + replace this with a better way to reason about dtypes of tensors. + """ + for node in graph.nodes: + non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node) + + for arg_type in non_observable_arg_dict: + non_observable_indices = non_observable_arg_dict[arg_type](node) + + for index in non_observable_indices: + arg = node.args[index] + + # when an argument is a tuple, it does not show up as another node so we need to go through + # all elements of the tuple manually + if isinstance(arg, (tuple, list)): + arg_list = list(arg) + else: + arg_list = [arg] + + for cur_arg in arg_list: + # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated + if isinstance(cur_arg, torch.fx.node.Node): + _maybe_propagate_dtype_for_node( + cur_arg, arg_type, node_name_to_match_result_with_qconfig + ) + + +def _maybe_make_input_output_share_observers( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], +) -> bool: + """ + Ensures that we share an observer + for all input arguments as well as the output argument. In detail, given + a graph of + + x0 -> obs0 -> op -> x2 + / + x1 -> obs1 / + + where node obs0 points to observer instance observer0, + obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 + and ob2 point to observer0. + Returns: whether the operation succeeded or not + """ + first_arg = None + # find the first non-Tensor arg + for i in range(len(node.args)): + if isinstance(node.args[i], (Node, list, tuple)): + first_arg = node.args[i] + break + + # if there is no non-Tensor arg, return directly + if first_arg is None: + return False + + if isinstance(first_arg, (list, tuple)): + first_arg_arg = first_arg[0] + elif isinstance(first_arg, Node): + first_arg_arg = first_arg + else: + return False + + # if we have a graph such as + # observed_node -> non_observed_node -> cat + # we need to navigate up to the first observer + iteration_guard = 0 + while not _is_activation_post_process_node(first_arg_arg, named_modules): + if not isinstance(first_arg_arg, Node): + return False + # did not find an activation_post_process for the op + if first_arg_arg.op == "placeholder": + return False + # trace back the args until we found the first Tensor/Node + trace_back_node = None + for i in range(len(first_arg_arg.args)): + trace_back_node = first_arg_arg.args[i] + if isinstance(trace_back_node, Node): + break + if trace_back_node is None: + return False + first_arg_arg = trace_back_node + + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + if not isinstance(first_arg_arg, Node): + raise AssertionError("first_arg_arg must be a Node") + target_to_use = first_arg_arg.target + if not isinstance(target_to_use, str): + raise AssertionError("target_to_use must be a string") + obs_mod_to_use = named_modules[target_to_use] + + if isinstance(first_arg, (list, tuple)): + # set all other input observer nodes to use that module + for input_idx, input_arg in enumerate(first_arg): + if input_idx == 0: + continue + iteration_guard = 0 + while not _is_activation_post_process_node(input_arg, named_modules): + # failed to trace back since no input arg for the current node + if len(input_arg.args) < 1: + return False + input_arg = input_arg.args[0] + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError("Unable to find observer of previous node") + + parent_name, name = _parent_name(input_arg.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # set the output observer node to use that module + for output_obs_node in node.users: + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) + parent_name, name = _parent_name(output_obs_node.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # TODO(future PR): delete the orphaned observer modules + return True + + +def _remove_output_observer( + node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module] +): + items = list(node.users.items()) + for output_obs_node, _ in items: + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) + output_obs_node.replace_all_uses_with(node) + model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] + + +def _swap_custom_module_to_observed( + node: Node, + qconfig: QConfigAny, + named_modules: dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, +): + custom_module = named_modules[node.target] # type: ignore[index] + custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping + observed_custom_module_class = get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig + ) + observed_custom_module = observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(named_modules[parent_name], name, observed_custom_module) + + +def insert_observers_for_model( + model: GraphModule, + node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig], + node_name_to_qconfig: dict[str, QConfigAny], + prepare_custom_config: PrepareCustomConfig, + equalization_config_map: dict[str, Any], + backend_config: BackendConfig, + observed_node_names: set[str], + is_qat: bool, +) -> Node | None: + """ + Inserts observers, using the following high level algorithm: + + For each node in the graph: + 1. determine the target dtype of this node in the quantized graph, and save + it for future steps + 2. determine the target dtype or all args and kwargs of this node + 3. if any arg or kwarg's target dtype does not match the current node's + dtype, insert an observer + 4. if the current node needs an output observer, insert it + + For example: + + - starting graph: + x0 -> linear -> x1 + + - observed graph after processing x0: + x0(fp32) + + - observed graph after processing linear: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) + + - observed graph after processing x1: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 + + After a node is processed, the naive observer placement is guaranteed to be + complete for that node and all of its predecessors. There can be future + passes which optimize the graph by deduplicating observers, etc. + """ + + # node.meta["target_dtype_info"] stores the target dtype information + # that's derived from qconfig for the Node, for example, if we have + # a conv2d node that has a qconfig + # qconfig = QConfig(activation=..., weight=...) + # # information for input and bias node omitted + # # for getattr node + # # weight = getattr(self, 'weight') + # weight.meta["target_dtype_info"] = { + # 'output_act_obs_or_fq_ctr': qconfig.weight, + # } + # # for conv2d node + # # conv2d = call_function[target=torch.nn.functional.conv2d]( + # # args=(input, weight, bias)) + # conv2d.meta["target_dtype_info"] = { + # 'input_act_obs_or_fq_ctr': qconfig.activation + # 'weight_obs_or_fq_ctr': qconfig.weight, + # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32), + # 'output_act_obs_or_fq_ctr': qconfig.activation, + # } + # + cache_for_no_tensor_check: dict[Node, bool] = {} + + # first, populate the dtype map based only on qconfig and qhandler + # this assumes: + # graph inputs are fp32 by default, and int8 where overridden + # other nodes output dtype is specified by the qconfig + named_modules = dict(model.named_modules(remove_duplicate=False)) + + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + processed_nodes: set[Node] = set() + # initialize target_dtype_info + for node in model.graph.nodes: + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + inputs_seen_counter = 0 + outputs_seen_counter = 0 + placeholder_node_to_input_index: dict[Node, int] = {} + # TODO: we probably don't need this counter since each graph will only have + # one output node? + output_node_to_output_index: dict[Node, int] = {} + for node in model.graph.nodes: + if node.op == "placeholder": + placeholder_node_to_input_index[node] = inputs_seen_counter + inputs_seen_counter += 1 + if node.op == "output": + output_node_to_output_index[node] = outputs_seen_counter + outputs_seen_counter += 1 + + # Step 1, set the observer or fake quantize module constructor for each node in the + # matched_node_pattern + + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + if qhandler is None: + raise AssertionError("qhandler must not be None") + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # Step 2. Special cases for some operators, we might be able to remove them + # in the future if we know dtype information of each node better + + # Step 2.1. some settings are not based on patterns, we need to process each node + # instead + for node in model.graph.nodes: + if ( + node.op == "placeholder" + and placeholder_node_to_input_index[node] in input_quantized_idxs + ): + # users are not supposed to call calculate_qparams on PlaceholderObserver, and + # this is OK because we are using this as a way to encode the dtypes of input + # tensor, we won't actually insert these observers in the graph and won't + # actually call calculate_qparams + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + elif node.op in ("call_module", "call_method", "call_function"): + args_have_no_tensors = all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check + ) + if args_have_no_tensors: + node.meta["target_dtype_info"] = { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + elif ( + node.op == "output" + and output_node_to_output_index[node] in output_quantized_idxs + ): + # TODO(future PR): update the output_quantized_idxs API to match + # arbitrary data structures. There is always a single output, and + # that output can have arbitrary nesting of values. List[int] is + # not the right data type for this. + + # TODO(future PR): support more dtypes in model outputs, if necessary + node.meta["target_dtype_info"] = copy.copy( + _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO + ) + + # Step 2.2, for nodes with known input dtypes, propagate them throughout the + # graph. For example, if there is a call such as + # x1 = x0.masked_fill(mask, 1) + # we propagate the type of mask to be torch.bool + propagate_dtypes_for_known_nodes( + model.graph, node_name_to_match_result_with_qconfig + ) + + # Step 3, check if the requested target_dtype_info is supported by backend or not + # if not, we'll reset the target_dtye_info to use the default (float Tensor) + + # reset the counters and set of processed_nodes + processed_nodes: set[Node] = set() + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = match_res_with_qconfig + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + if qhandler is None: + raise AssertionError("qhandler must not be None") + + # get output_act_dtype so that we don't also reset the special typed nodes + # TODO: we might want to handle these more uniformly with the default path + # this can be improved if we can use node.meta["val"] + output_act_or_fq_ctr = node.meta["target_dtype_info"][ + "output_act_obs_or_fq_ctr" + ] + output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None + output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq) + if not is_supported_by_backend and output_act_dtype not in [ + None, + int, + float, + torch.bool, + ]: + # restore target_dtype_info to default if it is not supported by backend + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig, + None, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes, + ) + + # After this point, the current node and all of its arguments + # have a target_dtype_info assigned. Now, we insert observers for inputs + # of this node (if needed for this node), and the output of this node + # (if needed for this node). + + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # Avoid duplicates custom module swaps for multiple nodes with same target. + custom_module_names_already_swapped: set[str] = set() + + # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index + # reset inputs/outputs counters + inputs_seen_counter = 0 + outputs_seen_counter = 0 + results_node = None + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + model_device = assert_and_get_unique_device(model) + + # TODO: change this to insert obs/fq by pattern instead of by node + for node in nodes_before_observation: + if node.op == "placeholder": + # if a graph input is in fp32, it does not need observation + # if a graph input is in int8, we assume the observation happens + # outside of the graph, and no additional observation is needed + pass + + elif node.op in ("call_module", "call_method", "call_function", "output"): + # check for matches + ( + last_node, + matched_node_pattern, + pattern, + qhandler, + qconfig, + ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment] + node.name, (None, None, None, None, None) + ) + equalization_qconfig = equalization_config_map.get(node.name, None) + + this_node_dtype_info = node.meta["target_dtype_info"] + if "val" in node.meta: + output_is_a_tensor = this_node_dtype_info is not None and isinstance( + node.meta["val"], FakeTensor + ) + else: + output_is_a_tensor = this_node_dtype_info is not None + + skip_inserting_observers = ( + (qconfig is None) or not output_is_a_tensor + ) and (node.op != "output") + + # TODO: take a closer look to see if we can remove this check + # right now it is here because of `observed_node_names`, we are using + # it as an indicator for swapping the modules to reference modules in + # convert + is_supported_by_backend = ( + _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config + ) + ) + + if not skip_inserting_observers and is_supported_by_backend: + named_modules = dict(model.named_modules(remove_duplicate=False)) + if node.op != "output": + if matched_node_pattern is None: + raise AssertionError("matched_node_pattern must not be None") + # add matched nodes to the observed node name set + _add_matched_node_name_to_set( + matched_node_pattern, observed_node_names + ) + + # This is currently only used for equalization. + # Checks if the current node is in a branch in which the two + # first layers are both being quantized. + # + # ex. conv2 + # / + # x -> conv1 + # + # If this is the case, we will not apply equalization to the + # initial two layers. + is_quantized_branch = False + if ( + len(node.args) > 0 + and isinstance(node.args[0], Node) + and len(node.args[0].users) > 1 + ): + for user in node.args[0].users: + # Checks if there exists another user being quantized + is_user_quantized = node_name_to_qconfig.get( + user.name, None + ) is not None or ( + user.op == "call_module" + and isinstance( + named_modules[str(user.target)], ObserverBase + ) + ) + if user != node and is_user_quantized: + is_quantized_branch = True + + pattern_to_root_node_getter = ( + get_fusion_pattern_to_root_node_getter(backend_config) + ) + root_node_getter = pattern_to_root_node_getter.get( + pattern, _default_root_node_getter + ) + root_node = root_node_getter(matched_node_pattern) + is_input_node_of_the_pattern = node is root_node + if is_input_node_of_the_pattern: + # this modifies node inplace + _maybe_insert_input_observers_for_node( + node, + qconfig, + model, + named_modules, + model.graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config, + model_device, + ) + + # insert equalization input observers if needed + _maybe_insert_input_equalization_observers_for_node( + node, + equalization_qconfig, + model, + named_modules, + model.graph, + is_quantized_branch, + ) + + is_last_node_of_pattern = node is last_node + input_output_share_observers = node.meta["target_dtype_info"].get( + "input_output_share_observers", False + ) + reuse_input_obs_or_fq = node.meta["target_dtype_info"].get( + "reuse_input_obs_or_fq", False + ) + + if is_last_node_of_pattern: + if _is_custom_module_lstm( + node, named_modules, qconfig, qhandler + ): + # Currently custom module outputs are assumed to be already quantized, + # so we need to insert a DeQuantStub after the output. For custom module + # LSTM specifically, the outputs are also a nested tuple, so we must first + # break down the tuple to insert DeQuantStubs after the internal nodes. + + # TODO: This currently diverges from how custom modules are handled today, + # where we insert observers after the output instead of DeQuantStubs, and + # replace these observers with "dequantize" nodes during convert. Conceptually, + # these output observers are the same as DeQuantStubs. In the future, we + # should resolve this inconsistency by inserting DeQuantStubs for all custom + # modules, not just for LSTM. + _insert_dequant_stubs_for_custom_module_lstm_output( + node, model, named_modules, model.graph + ) + if node.target not in custom_module_names_already_swapped: + custom_module_names_already_swapped.add(node.target) + _swap_custom_module_to_observed( + node, qconfig, named_modules, prepare_custom_config + ) + else: + # this returns the new observer node if it was needed + maybe_output_obs_node = ( + _maybe_insert_output_observer_for_node( + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + ) + ) + + if maybe_output_obs_node is not None: + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with( + node, maybe_output_obs_node + ) + + _is_observer_in_same_graph_ = ( + _is_observer_in_same_graph( + node, named_modules, obs_or_fq_map, is_qat + ) + ) + + # for ops whose inputs and outputs share observer/fqs, we modify the graph + # to make all inputs and outputs use the first input's + # observer/fq + if ( + input_output_share_observers + and _is_observer_in_same_graph_ + ) or reuse_input_obs_or_fq: + if not _maybe_make_input_output_share_observers( + node, model, named_modules + ): + _remove_output_observer( + node, model, named_modules + ) + + if qhandler is not None and qhandler.is_custom_module(): + if ( + node.target + not in custom_module_names_already_swapped + ): + custom_module_names_already_swapped.add( + node.target + ) + _swap_custom_module_to_observed( + node, + qconfig, + named_modules, + prepare_custom_config, + ) + + else: # output + _maybe_insert_observers_before_graph_output( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat + ) + + # + # After this point, the current node has input and output observers + # that it needs for itself inserted. + # + + # increment the counters, so future inputs and outputs are assigned + # correct dtypes + if node.op == "placeholder": + inputs_seen_counter += 1 + elif node.op == "output": + outputs_seen_counter += 1 + results_node = node + + return results_node + + +def _run_prepare_fx_on_standalone_modules( + model: torch.nn.Module, + is_qat: bool, + named_modules: dict[str, torch.nn.Module], + node_name_to_match_result_with_qconfig: Any, + prepare_custom_config: PrepareCustomConfig, + backend_config: BackendConfig, +) -> None: + """ + Runs prepare_fx on each standalone module. Note: this does + not modify the graph, it just replaces the unobserved modules with + their observed versions. + """ + for ( + root_node, + _, + _pattern, + qhandler, + qconfig, + ) in node_name_to_match_result_with_qconfig.values(): + if qhandler is None: + continue + elif not qhandler.is_standalone_module(): + continue + + ( + sm_qconfig_mapping, + sm_example_inputs, + sm_prepare_custom_config, + sm_backend_config, + ) = _get_standalone_module_configs( + root_node, named_modules, prepare_custom_config, qconfig, backend_config + ) + + standalone_module = named_modules[root_node.target] + prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] + observed_standalone_module = prepare( + standalone_module, + sm_qconfig_mapping, + is_qat, + example_inputs=sm_example_inputs, + prepare_custom_config=sm_prepare_custom_config, + backend_config=sm_backend_config, + ) + parent_name, name = _parent_name(root_node.target) + setattr(named_modules[parent_name], name, observed_standalone_module) + named_modules[root_node.target] = observed_standalone_module + + +def _save_state( + observed: GraphModule, + node_name_to_qconfig: dict[str, QConfigAny], + node_name_to_scope: dict[str, tuple[str, type]], + prepare_custom_config: PrepareCustomConfig, + equalization_node_name_to_qconfig: dict[str, Any], + qconfig_mapping: QConfigMapping, + is_qat: bool, + observed_node_names: set[str], +) -> None: + observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs( + node_name_to_qconfig=node_name_to_qconfig, + node_name_to_scope=node_name_to_scope, + prepare_custom_config=prepare_custom_config, + equalization_node_name_to_qconfig=equalization_node_name_to_qconfig, + qconfig_mapping=qconfig_mapping, + is_qat=is_qat, + observed_node_names=observed_node_names, + ) + + +def prepare( + model: GraphModule, + qconfig_mapping: QConfigMapping | dict[str, Any], + is_qat: bool, + node_name_to_scope: dict[str, tuple[str, type]], + example_inputs: tuple[Any, ...], + prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None, + _equalization_config: QConfigMapping | dict[str, Any] | None = None, + backend_config: BackendConfig | dict[str, Any] | None = None, + is_standalone_module: bool = False, +) -> GraphModule: + """standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + Args: + node_name_to_scope: mapping from node name to the scope of the module which contains the node. + The scope is a tuple of fully qualified path of the module and the type of the module + Returns: + model(GraphModule): prepared standalone module + attributes related to standalone module + in model.meta["_observed_graph_module_attrs"]: + is_observed_standalone_module (bool): boolean value that shows whether the + current model is a observed standalone module or not + standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + standalone_module_output_quantized_idxs(List[Int]): a list of + indices for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) + + if isinstance(_equalization_config, dict): + warnings.warn( + "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " + "be supported in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + _equalization_config = QConfigMapping.from_dict(_equalization_config) + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if not isinstance(qconfig_mapping, QConfigMapping): + raise AssertionError("qconfig_mapping must be a QConfigMapping") + if not isinstance(_equalization_config, QConfigMapping): + raise AssertionError("_equalization_config must be a QConfigMapping") + qconfig_mapping = copy.deepcopy(qconfig_mapping) + _equalization_config = copy.deepcopy(_equalization_config) + + # mapping from a tuple of nodes in reverse order to uninitialized + # QuantizeHandler subclass. For example, + # { + # # match a single node + # (: + # ), + # # match multiple nodes in reverse order + # ((, ): + # ), + # } + + pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = {} + if backend_config is None: + backend_config = get_native_backend_config() + pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) + pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler) + + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_fusion(model, qconfig_mapping) + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_fusion(model, _equalization_config) + # pyrefly: ignore [bad-argument-type] + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) + # TODO: support regex as well + propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) + + if is_qat: + module_to_qat_module = get_module_to_qat_module(backend_config) + _qat_swap_modules(model, module_to_qat_module) + # pyrefly: ignore [bad-argument-type] + _update_qconfig_for_qat(qconfig_mapping, backend_config) + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + named_modules = dict(model.named_modules(remove_duplicate=False)) + + # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches + equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, + named_modules, + model.graph, + # pyrefly: ignore [bad-argument-type] + _equalization_config, + node_name_to_scope, + ) + node_name_to_qconfig = _generate_node_name_to_qconfig( + model, + named_modules, + model.graph, + # pyrefly: ignore [bad-argument-type] + qconfig_mapping, + node_name_to_scope, + ) + + # match the patterns that will get quantized + standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) + standalone_module_classes = list( + prepare_custom_config.standalone_module_classes.keys() + ) + + custom_module_classes = get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + matches_without_qconfig = _find_matches( + model.graph, + named_modules, + pattern_to_quantize_handler, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + + # map qconfig instances to matches + node_name_to_match_result_with_qconfig = {} + for node_name, match_without_qconfig in matches_without_qconfig.items(): + match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) + node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig + + _run_prepare_fx_on_standalone_modules( + model, + is_qat, + named_modules, + node_name_to_match_result_with_qconfig, + prepare_custom_config, + backend_config, + ) + + # record names for the set of observed node, so that in convert step + # we know whether we need to convert a floating point module to reference + # quantized module or not + observed_node_names: set[str] = set() + + result_node = insert_observers_for_model( + model, + node_name_to_match_result_with_qconfig, + node_name_to_qconfig, + prepare_custom_config, + equalization_node_name_to_qconfig, + backend_config, + observed_node_names, + is_qat, + ) + model = GraphModule(model, model.graph) + + _save_state( + model, + node_name_to_qconfig, + node_name_to_scope, + prepare_custom_config, + equalization_node_name_to_qconfig, + # pyrefly: ignore [bad-argument-type] + qconfig_mapping, + is_qat, + observed_node_names, + ) + + if is_standalone_module: + if result_node is None: + raise AssertionError("result_node must not be None for standalone modules") + if not isinstance(result_node.args[0], Node): + raise AssertionError( + "standalone module only supports returning simple value currently (not tuple, dict etc.)" + ) + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = ( + prepare_custom_config.output_quantized_indexes + ) + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + # inplace modification + observed_graph_module_attrs.is_observed_standalone_module = True + observed_graph_module_attrs.standalone_module_input_quantized_idxs = ( + input_quantized_idxs + ) + observed_graph_module_attrs.standalone_module_output_quantized_idxs = ( + output_quantized_idxs + ) + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..783cba8149e6e09164d01c7f9ebafdc2e6240428 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -0,0 +1,401 @@ +# mypy: allow-untyped-defs +import re +from collections import defaultdict, OrderedDict +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization import QConfig +from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +from torch.ao.quantization.backend_config.utils import get_module_to_qat_module +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _add_module_to_qconfig_obs_ctr, + qconfig_equals, + QConfigAny, +) +from torch.ao.quantization.qconfig_mapping import ( + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + QConfigMapping, +) +from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes +from torch.fx import GraphModule +from torch.fx.graph import Graph + + +__all__: list[str] = [] + + +def _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping: QConfigMapping, + cur_module_path: str, + cur_object_type: Callable, + cur_object_type_idx: int, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + for ( + module_name, + object_type, + index, + ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items(): + if ( + (module_name == cur_module_path) + and (object_type == cur_object_type) + and (index == cur_object_type_idx) + ): + return qconfig + return fallback_qconfig + + +def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping): + """ + Update the QConfigMapping to account for fused modules such as LinearReLU. + This assumes the QConfigMapping's attributes have already been converted to OrderedDicts. + """ + object_type_dict = qconfig_mapping.object_type_qconfigs + if len(object_type_dict) == 0: + return qconfig_mapping + + modules = dict(model.named_modules()) + + for node in model.graph.nodes: + if node.op == "call_module" and node.target in modules: + maybe_fused_module = modules[str(node.target)] + if not isinstance(maybe_fused_module, _FusedModule): + continue + + ops = list(maybe_fused_module._modules.values()) + fused_qconfig = object_type_dict.get(type(ops[0]), None) + + # Raise an error if the modules in the fused module have + # different qconfigs specified in the qconfig_dict + # TODO: currently it only works for modules, + # need to make this work for torch.nn.functional.relu + # TODO: currently it only works for object_type configurations, + # ideally it should work for different types of configurations, + # maybe we want to redesign this part + for op in ops[1:]: + if not qconfig_equals( + object_type_dict.get(type(op), None), fused_qconfig + ): + raise LookupError( + "During fusion, we need to specify the same " + + f"qconfigs for all module types in {type(maybe_fused_module)} " + + f"offending type: {type(op)}" + ) + + if fused_qconfig is not None: + object_type_dict[type(maybe_fused_module)] = fused_qconfig + + +def _generate_node_name_to_qconfig( + root: torch.nn.Module, + modules: dict[str, torch.nn.Module], + input_graph: Graph, + qconfig_mapping: QConfigMapping, + node_name_to_scope: dict[str, tuple[str, type]], +) -> dict[str, QConfigAny]: + global_qconfig = qconfig_mapping.global_qconfig + node_name_to_qconfig = {} + + # example: + # + # {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...} + # + # meaning in submodule 'foo.bar', we have seen 0 F.linear and + # 1 F.conv2d invocations so far. + submodule_to_object_type_to_cur_idx: dict[str, dict[Callable, int]] = defaultdict( + lambda: defaultdict(int) + ) + for node in input_graph.nodes: + qconfig = None + if node.op == "get_attr": + module_name, _ = _parent_name(node.target) + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[module_name]), module_name, global_qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + elif node.op == "call_function": + # precedence: module_name_qconfig + # > function_qconfig > global_qconfig + # module_name takes precedence over function qconfig + function_qconfig = _get_object_type_qconfig( + qconfig_mapping, node.target, global_qconfig + ) + module_path, module_type = node_name_to_scope[node.name] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, function_qconfig + ) + + cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][ + node.target + ] + submodule_to_object_type_to_cur_idx[module_path][node.target] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_method": + module_path, module_type = node_name_to_scope[node.name] + # first use node.target (string) to get the qconfig + # this is to support configs like + # "object_type": [("reshape", qconfig)] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, node.target, module_path, global_qconfig + ) + # if there is no special config for the method, we'll fall back to the + # config for the module that contains the call_method node + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, qconfig + ) + # currently call_method does not support modifying qconfig + # by order, we can add this later if it is needed. + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + elif node.op == "call_module": + # if the node is an observer, just continue - don't add it to the qconfig_map + if _is_activation_post_process(modules[node.target]): + continue + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[node.target]), node.target, global_qconfig + ) + + module_path, module_type = node_name_to_scope[node.name] + # Note: for call_module, the module_path is the current module's name. + # to meaningfully count invocations, we need to count them in the parent + # module. + parent_name, _ = _parent_name(module_path) + cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][ + module_type + ] + submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig + ) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr( + qconfig, modules.get(node.target, None) + ) + + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex + # is used + modules[node.target].qconfig = qconfig_with_device_check + else: + qconfig_with_device_check = None + + node_name_to_qconfig[node.name] = qconfig_with_device_check + return node_name_to_qconfig + + +def _check_is_valid_config_dict( + config_dict: Any, allowed_keys: set[str], dict_name: str +) -> None: + r"""Checks if the given config_dict has the correct keys + + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict: + if k not in allowed_keys: + raise ValueError( + "Expected " + + dict_name + + " to have the following keys: " + + str(allowed_keys) + + ". But found '" + + k + + "' instead." + ) + + +def _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping +): + r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values + + Args: + `prepare_qconfig_mapping`: configuration for prepare quantization step + `convert_qconfig_mapping`: configuration for convert quantization step + """ + if not qconfig_equals( + prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig + ): + raise AssertionError( + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ) + prepare_dicts: list[OrderedDict] = [ + prepare_qconfig_mapping.object_type_qconfigs, + prepare_qconfig_mapping.module_name_qconfigs, + prepare_qconfig_mapping.module_name_regex_qconfigs, + ] + convert_dicts: list[OrderedDict] = [ + convert_qconfig_mapping.object_type_qconfigs, + convert_qconfig_mapping.module_name_qconfigs, + convert_qconfig_mapping.module_name_regex_qconfigs, + ] + dict_names = [ + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + ] + for i in range(len(prepare_dicts)): + for name in prepare_dicts[i]: + if name not in convert_dicts[i]: + raise AssertionError( + f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" + ) + if convert_dicts[i][name] is not None and not qconfig_equals( + prepare_dicts[i][name], convert_dicts[i][name] + ): + raise AssertionError( + "Expected convert QConfigMapping to have the same qconfig as prepare for key " + f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + ) + + +def _is_qconfig_supported_by_dtype_configs( + qconfig: QConfig, dtype_configs: list[DTypeConfig] +): + for dtype_config in dtype_configs: + is_dynamic = dtype_config.is_dynamic + if is_dynamic is None: + is_dynamic = False + input_dtype = dtype_config.input_dtype or torch.float + weight_dtype = dtype_config.weight_dtype or torch.float + bias_dtype = dtype_config.bias_dtype or torch.float + output_dtype = dtype_config.output_dtype or torch.float + ( + qconfig_activation_dtype, + qconfig_weight_dtype, + qconfig_input_act_is_dynamic, + ) = get_qconfig_dtypes(qconfig) + qconfig_bias_dtype = ( + torch.float16 + if ( + qconfig_activation_dtype == torch.float16 + and qconfig_weight_dtype == torch.float16 + and not is_dynamic + ) + else torch.float + ) + + if is_dynamic: + is_match = ( + qconfig_input_act_is_dynamic + and input_dtype == qconfig_activation_dtype + and output_dtype == torch.float + and weight_dtype == qconfig_weight_dtype + ) + else: + is_match = ( + input_dtype == qconfig_activation_dtype + and output_dtype == qconfig_activation_dtype + and weight_dtype == qconfig_weight_dtype + and bias_dtype == qconfig_bias_dtype + ) + if is_match: + return True + return False + + +def _get_object_type_qconfig( + qconfig_mapping: QConfigMapping, + object_type: Callable | str, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) + + +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): + for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + + +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): + if module_name == "": + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_mapping.module_name_qconfigs: + return qconfig_mapping.module_name_qconfigs[module_name] + else: + parent, _ = _parent_name(module_name) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + + +def _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_name, global_qconfig +): + # get qconfig for module_name, + # fallback to module_name_regex_qconfig, module_type_qconfig, + # global_qconfig if necessary + module_type_qconfig = _get_object_type_qconfig( + qconfig_mapping, module_type, global_qconfig + ) + module_name_regex_qconfig = _get_module_name_regex_qconfig( + qconfig_mapping, module_name, module_type_qconfig + ) + module_name_qconfig = _get_module_name_qconfig( + qconfig_mapping, module_name, module_name_regex_qconfig + ) + return module_name_qconfig + + +def _get_flattened_qconfig_dict( + qconfig_mapping: QConfigMapping, +) -> dict[Callable | str, QConfigAny]: + """flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened: dict[Callable | str, QConfigAny] = {"": qconfig_mapping.global_qconfig} + flattened.update(qconfig_mapping.object_type_qconfigs) + flattened.update(qconfig_mapping.module_name_qconfigs) # type: ignore[arg-type] + return flattened + + +def _update_qconfig_for_qat( + qconfig_mapping: QConfigMapping, backend_config: BackendConfig +): + """ + Update the qconfig_mapping to account for module swaps during QAT. + During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. + """ + module_to_qat_module_class = get_module_to_qat_module(backend_config) + object_type_dict = qconfig_mapping.object_type_qconfigs + new_object_type_dict = object_type_dict.copy() + for k, v in new_object_type_dict.items(): + if k in module_to_qat_module_class: + object_type_dict[module_to_qat_module_class[k]] = v diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd8d7fe3a17439b46ad673a5aaf7eae28b7082f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py @@ -0,0 +1,226 @@ +# mypy: allow-untyped-defs +from abc import ABC +from collections.abc import Callable + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, +) +from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls +from torch.fx.graph import Node + +from .utils import all_node_args_have_no_tensors + + +__all__ = [ + "QuantizeHandler", + "BinaryOpQuantizeHandler", + "CatQuantizeHandler", + "ConvReluQuantizeHandler", + "LinearReLUQuantizeHandler", + "BatchNormQuantizeHandler", + "EmbeddingQuantizeHandler", + "RNNDynamicQuantizeHandler", + "DefaultNodeQuantizeHandler", + "FixedQParamsOpQuantizeHandler", + "CopyNodeQuantizeHandler", + "GeneralTensorShapeOpQuantizeHandler", + "CustomModuleQuantizeHandler", + "StandaloneModuleQuantizeHandler", +] + + +def _default_root_node_getter(node_pattern): + if node_pattern is None: + return node_pattern + while not isinstance(node_pattern, Node): + node_pattern = node_pattern[-1] + return node_pattern + + +# Base Pattern Handler +class QuantizeHandler(ABC): # noqa: B024 + """Base handler class for the quantizer patterns""" + + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Callable | None = None, + is_custom_module=False, + is_standalone_module=False, + ): + """Records pattern information in __init__, which will be used + in convert + """ + self.node_pattern = node_pattern + self.modules = modules + if root_node_getter is None: + root_node_getter = _default_root_node_getter + self.root_node = root_node_getter(node_pattern) + self.is_custom_module_ = is_custom_module + self.is_standalone_module_ = is_standalone_module + self.num_tensor_args = 0 + # determine how many of the first two args are Tensors (versus scalars) + # this distinguishes things like "x + y" from "x + 2" or "2 + x" + if isinstance(self.root_node, Node): + cache_for_no_tensor_check: dict[Node, bool] = {} + for arg_idx in range(len(self.root_node.args)): + arg = self.root_node.args[arg_idx] + if isinstance(arg, Node) and ( + not all_node_args_have_no_tensors( + arg, self.modules, cache_for_no_tensor_check + ) + ): + self.num_tensor_args += 1 + + def is_general_tensor_value_op(self) -> bool: + """ + Returns True if the operator works for both floating point and + quantized input, and does some computation based on the input Tensor, + or the ops that only re-arranges the Tensor values or query some metadata + about the Tensor + so we need to insert observer/fake_quant for the output of the + operator (same observer instance as input) + since the distribution of values is different for input and output + Tensors (for HistogramObserver) while they share the same quantization + parameters + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + return False + + def is_custom_module(self): + return self.is_custom_module_ + + def is_standalone_module(self): + return self.is_standalone_module_ + + +def _get_quantize_handler_cls( + observation_type: ObservationType, + dtype_configs: list[DTypeConfig], + num_tensor_args_to_observation_type: dict[int, ObservationType], +) -> type[QuantizeHandler]: + """ + Return a configurable QuantizeHandler that matches the given specifications from the backend. + """ + + class ConfigurableQuantizeHandler(QuantizeHandler): + def __init__( + self, + node_pattern: NodePattern, + modules: dict[str, torch.nn.Module], + root_node_getter: Callable | None = None, + ): + super().__init__(node_pattern, modules, root_node_getter) + if num_tensor_args_to_observation_type: + if self.num_tensor_args not in num_tensor_args_to_observation_type: + raise AssertionError( + f"Must provide observation_type config for tensor number {self.num_tensor_args}" + f" in num_tensor_args_to_observation_type for {node_pattern}" + ) + self.observation_type = num_tensor_args_to_observation_type[ + self.num_tensor_args + ] + else: + self.observation_type = observation_type + self.dtype_configs = dtype_configs + + def is_general_tensor_value_op(self) -> bool: + return ( + self.observation_type + == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + ) + + return ConfigurableQuantizeHandler + + +def _get_pattern_to_quantize_handlers( + backend_config: BackendConfig, +) -> dict[Pattern, QuantizerCls]: + """ + Note: Quantize handler is just a holder for some check methods like + (should_insert_observer_for_output), maybe this can be a enum as well, + we can refactor this after we convert the path for fbgemm/qnnpack fully to the + new path, this is not exposed to backend developers + """ + pattern_to_quantize_handlers = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + observation_type = config.observation_type + dtype_configs = config.dtype_configs + num_tensor_args_to_observation_type = ( + config._num_tensor_args_to_observation_type + ) + pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls( + observation_type, dtype_configs, num_tensor_args_to_observation_type + ) + return pattern_to_quantize_handlers + + +# TODO: remove this class, this is still exposed in torch.ao.quantization +# but we should be able to break bc +class BinaryOpQuantizeHandler(QuantizeHandler): + pass + + +class CatQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class ConvReluQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class LinearReLUQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class BatchNormQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class EmbeddingQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class RNNDynamicQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove this class +class DefaultNodeQuantizeHandler(QuantizeHandler): + """Common quantized op, first input and first output will be quantized""" + + +# TODO: remove this class +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class CopyNodeQuantizeHandler(QuantizeHandler): + pass + + +# TODO: remove +class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler): + pass + + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class CustomModuleQuantizeHandler(QuantizeHandler): + pass + + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class StandaloneModuleQuantizeHandler(QuantizeHandler): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1635936845a44ab895a3c6b0c5e07e9ec9951e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py @@ -0,0 +1,48 @@ +from collections.abc import Callable + +import torch +from torch.ao.nn.intrinsic import _FusedModule +from torch.fx._symbolic_trace import Tracer +from torch.fx.proxy import Scope + + +__all__ = [ + "QuantizationTracer", +] + + +class ScopeContextManager(torch.fx.proxy.ScopeContextManager): + def __init__( + self, scope: Scope, current_module: torch.nn.Module, current_module_path: str + ): + super().__init__(scope, Scope(current_module_path, type(current_module))) + + +class QuantizationTracer(Tracer): + def __init__( + self, skipped_module_names: list[str], skipped_module_classes: list[Callable] + ): + super().__init__() + self.skipped_module_names = skipped_module_names + self.skipped_module_classes = skipped_module_classes + # NB: initialized the module_type of top level module to None + # we are assuming people won't configure the model with the type of top level + # module here, since people can use "" for global config + # We can change this if there is a use case that configures + # qconfig using top level module type + self.scope = Scope("", None) + self.record_stack_traces = True + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + return ( + ( + ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) + and not isinstance(m, torch.nn.Sequential) + ) + or module_qualified_name in self.skipped_module_names + or type(m) in self.skipped_module_classes + or isinstance(m, _FusedModule) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a46d2057c5480ae036bbb847cf3d9bb185b29ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py @@ -0,0 +1,997 @@ +# mypy: allow-untyped-defs +import copy +import functools +import operator +import warnings +from collections import namedtuple +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +from torch.ao.quantization import QConfigAny, QuantType +from torch.ao.quantization.backend_config import DTypeWithConstraints +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + _is_activation_post_process, + FixedQParamsObserver, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + float16_dynamic_qconfig, + float16_static_qconfig, + qconfig_equals, +) +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _assert_and_get_unique_device, + activation_is_statically_quantized, +) +from torch.fx import GraphModule, map_arg +from torch.fx.graph import Graph, Node + +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +from .custom_config import PrepareCustomConfig + + +# TODO: revisit this list. Many helper methods shouldn't be public +__all__ = [ + "all_node_args_except_first", + "all_node_args_have_no_tensors", + "assert_and_get_unique_device", + "collect_producer_nodes", + "create_getattr_from_value", + "create_node_from_old_node_preserve_meta", + "EMPTY_ARG_DICT", + "get_custom_module_class_keys", + "get_linear_prepack_op_for_dtype", + "get_new_attr_name_with_prefix", + "get_non_observable_arg_indexes_and_types", + "get_qconv_prepack_op", + "get_skipped_module_name_and_classes", + "graph_module_from_producer_nodes", + "maybe_get_next_module", + "NodeInfo", + "node_arg_is_bias", + "node_arg_is_weight", + "NON_OBSERVABLE_ARG_DICT", + "NON_QUANTIZABLE_WEIGHT_OPS", + "return_arg_list", + "ObservedGraphModuleAttrs", +] + +NON_QUANTIZABLE_WEIGHT_OPS = { + torch.nn.functional.layer_norm, + torch.nn.functional.group_norm, + torch.nn.functional.instance_norm, +} + + +@dataclass +class ObservedGraphModuleAttrs: + node_name_to_qconfig: dict[str, QConfigAny] + node_name_to_scope: dict[str, tuple[str, type]] + prepare_custom_config: PrepareCustomConfig + equalization_node_name_to_qconfig: dict[str, Any] + qconfig_mapping: QConfigMapping + is_qat: bool + observed_node_names: set[str] + is_observed_standalone_module: bool = False + standalone_module_input_quantized_idxs: list[int] | None = None + standalone_module_output_quantized_idxs: list[int] | None = None + + +def node_arg_is_weight(node: Node, arg: Any) -> bool: + """Returns if node arg is weight""" + weight_index = None + if "target_dtype_info" in node.meta: + weight_index = node.meta["target_dtype_info"].get("weight_index", None) + if ( + weight_index is not None + and weight_index < len(node.args) + and node.args[weight_index] is arg + ): + return True + return node.kwargs.get("weight") is arg + + +def node_arg_is_bias(node: Node, arg: Any) -> bool: + """Returns if node arg is bias""" + bias_index = None + if "target_dtype_info" in node.meta: + bias_index = node.meta["target_dtype_info"].get("bias_index", None) + if ( + bias_index is not None + and bias_index < len(node.args) + and node.args[bias_index] is arg + ): + return True + return node.kwargs.get("bias") is arg + + +def get_custom_module_class_keys( + custom_module_mapping: dict[QuantType, dict[type, type]], +) -> list[Any]: + r"""Get all the unique custom module keys in the custom config dict + e.g. + Input: + { + QuantType.STATIC: { + CustomModule1: ObservedCustomModule + }, + QuantType.DYNAMIC: { + CustomModule2: DynamicObservedCustomModule + }, + QuantType.WEIGHT_ONLY: { + CustomModule3: WeightOnlyObservedCustomModule + }, + } + + Output: + # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes: set[Any] = set() + for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + + +def get_linear_prepack_op_for_dtype(dtype): + if dtype == torch.float16: + return torch.ops.quantized.linear_prepack_fp16 + elif dtype == torch.qint8: + return torch.ops.quantized.linear_prepack + else: + raise Exception("can't get linear prepack op for dtype:", dtype) # noqa: TRY002 + + +def get_qconv_prepack_op(conv_op: Callable) -> Callable: + prepack_ops = { + torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack, + torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack, + torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack, + torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack, + torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack, + torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack, + } + prepack_op = prepack_ops.get(conv_op) + if prepack_op is None: + raise AssertionError(f"Didn't find prepack op for {conv_op}") + return prepack_op + + +# Returns a function that can get a new attribute name for module with given +# prefix, for example, +# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') +# >> new_name = get_new_observer_name(module) +# new_name will be an unused attribute name on module, e.g. `_observer_1` +def get_new_attr_name_with_prefix(prefix: str) -> Callable: + prefix = prefix.replace(".", "_") + + def get_new_attr_name(module: torch.nn.Module): + def get_attr_name(i: int): + return prefix + str(i) + + i = 0 + attr_name = get_attr_name(i) + while hasattr(module, attr_name): + i += 1 + attr_name = get_attr_name(i) + return attr_name + + return get_new_attr_name + + +def collect_producer_nodes(node: Node) -> list[Node] | None: + r"""Starting from a target node, trace back until we hit input or + getattr node. This is used to extract the chain of operators + starting from getattr to the target node, for example:: + + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + + collect_producer_nodes(observed) will either return a list of nodes that + produces the observed node or None if we can't extract a self contained + graph without free variables(inputs of the forward function). + """ + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == "placeholder": + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == "call_function" and arg.target is getattr): + frontier.append(arg) + return nodes + + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: list[Node] +) -> GraphModule: + r"""Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + """ + if len(producer_nodes) == 0: + raise AssertionError("list of producer nodes can not be empty") + # since we traced back from node to getattr + producer_nodes.reverse() + graph = Graph() + env: dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node]) + + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + + +# TODO: delete +@functools.cache +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + return _assert_and_get_unique_device(module) + + +def create_getattr_from_value( + module: torch.nn.Module, + graph: Graph, + prefix: str, + value: Any, + device: torch.device | None = None, +) -> Node: + """ + Given a value of any type, creates a getattr node corresponding to the value and + registers the value as a buffer to the module. + """ + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + attr_name = get_new_attr_name(module) + if device is None: + device = assert_and_get_unique_device(module) + new_value = ( + value.detach().clone() + if isinstance(value, torch.Tensor) + else torch.tensor(value, device=device) + ) + module.register_buffer(attr_name, new_value) + # Create get_attr with value + attr_node = graph.create_node("get_attr", attr_name) + return attr_node + + +def all_node_args_have_no_tensors( + node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool] +) -> bool: + """ + If we know for sure that all of this node's args have no + tensors (are primitives), return True. If we either + find a tensor or are not sure, return False. Note: this + function is not exact. + """ + if cache and node in cache: + return cache[node] + + result = False # will be overwritten + if not isinstance(node, Node): + result = True + elif node.op == "placeholder": + result = False + elif node.op == "call_module": + if not isinstance(node.target, str): + raise AssertionError("node.target must be a string for call_module nodes") + if _is_activation_post_process(modules[node.target]): + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "call_module": + result = False + elif node.op == "call_function" and node.target is operator.getitem: + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == "get_attr": + result = False + elif node.target is getattr and node.args[1] in ["ndim", "shape"]: + # x1 = x0.ndim + result = True + elif node.op == "call_method" and node.target == "size": + # x1 = x0.size(0) + result = True + else: + found_one_tensor = False + for arg in node.args: + if isinstance(arg, list): + for list_el in arg: + if isinstance(list_el, Node): + this_list_el_args_have_no_tensors = ( + all_node_args_have_no_tensors(list_el, modules, cache) + ) + found_one_tensor = found_one_tensor or ( + not this_list_el_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + elif isinstance(arg, int): + pass + else: + if isinstance(arg, Node): + this_arg_args_have_no_tensors = all_node_args_have_no_tensors( + arg, modules, cache + ) + found_one_tensor = found_one_tensor or ( + not this_arg_args_have_no_tensors + ) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + else: + found_one_tensor = True + result = not found_one_tensor + if cache: + cache[node] = result + return result + + +def all_node_args_except_first(node: Node) -> list[int]: + """ + Returns all node arg indices after first + """ + return list(range(1, len(node.args))) + + +def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]: + """ + Constructs a function that takes a node as arg and returns the arg_indices + that are valid for node.args + """ + + def arg_indices_func(node: Node) -> list[int]: + return [i for i in arg_indices if i < len(node.args)] + + return arg_indices_func + + +NodeInfo = namedtuple("NodeInfo", "op target") + +# this dict identifies which indices of a node are non tensors +# so that they can be propagated correctly since inserting observers +# for them would cause errors + +NON_OBSERVABLE_ARG_DICT: dict[ + NodeInfo, dict[type | torch.dtype, Callable[[Node], list[int]]] +] = { + NodeInfo("call_method", "masked_fill"): { + torch.bool: return_arg_list([1]), + float: return_arg_list([2]), + }, + NodeInfo("call_method", "permute"): {int: all_node_args_except_first}, + NodeInfo("call_method", "repeat"): {int: all_node_args_except_first}, + NodeInfo("call_method", "reshape"): {int: all_node_args_except_first}, + NodeInfo("call_method", "size"): {int: return_arg_list([1])}, + NodeInfo("call_method", "transpose"): {int: all_node_args_except_first}, + NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first}, + NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])}, + NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])}, + NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])}, + NodeInfo("call_method", "view"): {int: all_node_args_except_first}, +} + +EMPTY_ARG_DICT: dict[type | torch.dtype, Callable[[Node], list[int]]] = {} + + +def get_non_observable_arg_indexes_and_types( + node: Node, +) -> dict[type | torch.dtype, Callable[[Node], list[int]]]: + """ + Returns a dict with of non float tensor types as keys and values which correspond to a + function to retrieve the list (which takes the node as an argument) + """ + info = NodeInfo(node.op, node.target) + + return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT) + + +def maybe_get_next_module( + node: Node, + modules: dict[str, nn.Module], + target_module_type: type[nn.Module] | None = None, + target_functional_type: Any = None, +) -> Node | None: + """Gets the next module that matches what is needed in + is_target_module_type if it exists + + Args: + node: The node whose users we want to look at + target_module_type: Module type that we want to check + target_functional_type: Functional type that we want to check + """ + + for user in node.users: + if ( + user.op == "call_module" + and target_module_type is not None + and isinstance(modules[str(user.target)], target_module_type) + ): + return user + elif ( + user.op == "call_function" + and target_functional_type is not None + and user.target == target_functional_type + ): + return user + + return None + + +def create_node_from_old_node_preserve_meta( + quantized_graph: Graph, + create_node_args: tuple[Any, ...], + old_node: Node, +) -> Node: + """ + Creates `new_node` and copies the necessary metadata to it from `old_node`. + """ + new_node = quantized_graph.create_node(*create_node_args) + new_node.stack_trace = old_node.stack_trace + return new_node + + +def get_skipped_module_name_and_classes( + prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool +) -> tuple[list[str], list[type[Any]]]: + skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) + skipped_module_classes = copy.copy( + prepare_custom_config.non_traceable_module_classes + ) + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + skipped_module_names += list( + prepare_custom_config.standalone_module_names.keys() + ) + skipped_module_classes += list( + prepare_custom_config.standalone_module_classes.keys() + ) + skipped_module_classes += get_custom_module_class_keys( + prepare_custom_config.float_to_observed_mapping + ) + + return skipped_module_names, skipped_module_classes + + +def _is_custom_module_lstm( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Any | None = None, +) -> bool: + """ + Return whether this refers to the custom module LSTM flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + if not isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") + return ( + isinstance(mod, torch.nn.LSTM) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.LSTM) + + +def _is_custom_module_mha( + node: Node, + named_modules: dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Any | None = None, +) -> bool: + """ + Return whether this refers to the custom module MultiheadAttention flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + if not isinstance( + qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") + return ( + isinstance(mod, torch.nn.MultiheadAttention) + and activation_is_statically_quantized(qconfig) + and qhandler.is_custom_module() + ) + else: + return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention) + + +def _get_module( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> torch.nn.Module | None: + """ + If `node` refers to a call_module node, return the module, else None. + """ + if node.op == "call_module" and str(node.target) in named_modules: + return named_modules[str(node.target)] + else: + return None + + +def _insert_dequant_stub( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attach a `DeQuantStub` to the model and create a node that calls this + `DeQuantStub` on the output of `node`, similar to how observers are inserted. + """ + prefix = "dequant_stub_" + get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix) + dequant_stub_name = get_new_dequant_stub_name(model) + dequant_stub = DeQuantStub() + setattr(model, dequant_stub_name, dequant_stub) + named_modules[dequant_stub_name] = dequant_stub + with graph.inserting_after(node): + return graph.call_module(dequant_stub_name, (node,)) + + +def _insert_dequant_stubs_for_custom_module_lstm_output( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Insert DeQuantStubs after each internal output node of custom module LSTM. + + Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)), + Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its + components through `getitem`. This function transforms the graph as follows: + + (1) Split the LSTM node into (output, (hidden0, hidden1)) + (2) Insert a DeQuantStub after each internal node + (3) Recombine the DeQuantStubs into the same structure as before + (4) Reroute all consumers of the original LSTM node and its sub-nodes + (e.g. lstm[0]) + + Before: + lstm_output + | + v + original_user(s) + After: + lstm_output + / \\ + / (getitem) \\ + / \\ + v v + output hidden + | / \\ + (DeQuantStub) (getitem) + | / \\ + v v v + output_dq hidden0 hidden1 + | | | + | (DeQuantStub) (DeQuantStub) + | | | + | v v + | hidden0_dq hidden1_dq + | \\ / + | (tuple) + | \\ / + | v v + | hidden_dq + \\ / + \\ (tuple) / + v v + lstm_output_dq + | + v + original_user(s) + + For step (4), reroute all users of the original LSTM node(s) as follows: + lstm_output -> lstm_output_dq + lstm_output[0] -> output_dq + lstm_output[1] -> hidden_dq + lstm_output[1][0] -> hidden0_dq + lstm_output[1][1] -> hidden1_dq + + Return the node `lstm_output_dq`. + """ + # (1) Split the LSTM node into (output, (hidden0, hidden1)) + # (2) Insert a DeQuantStub after each internal node + with graph.inserting_after(node): + output = graph.call_function(operator.getitem, (node, 0)) + output_dq = _insert_dequant_stub(output, model, named_modules, graph) + with graph.inserting_after(output_dq): + hidden = graph.call_function(operator.getitem, (node, 1)) + with graph.inserting_after(hidden): + hidden0 = graph.call_function(operator.getitem, (hidden, 0)) + hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph) + with graph.inserting_after(hidden0_dq): + hidden1 = graph.call_function(operator.getitem, (hidden, 1)) + hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph) + + # (3) Recombine the DeQuantStubs into the same structure as before + with graph.inserting_after(hidden1_dq): + hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],)) + with graph.inserting_after(hidden_dq): + lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],)) + + # (4) Reroute all consumers of the original LSTM node and its sub-nodes + for user in list(node.users.keys()): + if user != output and user != hidden: + user.replace_input_with(node, lstm_output_dq) + # The getitem and tuple nodes we added here may interfere with reference quantized + # pattern matching, so we need to redirect the consumers of internal nodes to the + # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached, + # in order to preserve reference patterns like "dequantize - consumer - quantize". + _reroute_tuple_getitem_pattern(graph) + return lstm_output_dq + + +def _maybe_get_custom_module_lstm_from_node_arg( + arg: Node, + named_modules: dict[str, torch.nn.Module], +) -> Node | None: + """ + Given an argument of a node, if the argument refers to the path through which the node + is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise. + + This is used to determine whether a node is a consumer of custom module LSTM, and, if so, + skip inserting input observers for this node. This is because custom module LSTM produces + quantized outputs, so inserting an input observer for the consumer of custom module LSTM + would unnecessarily quantize the outputs again. + + lstm -> consumer + + In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with + DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + This tuple can be consumed in one of four ways: + + lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0] + lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1] + lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1] + lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm + + Thus, we must match against the above patterns instead of simply checking the parent node + to determine whether this node is a consumer of a custom module LSTM. + """ + + def match_dq(a): + return isinstance(_get_module(a, named_modules), DeQuantStub) + + def match_lstm(a): + return _is_custom_module_lstm(a, named_modules) + + def match_getitem(a): + return a.op == "call_function" and a.target is operator.getitem + + def match_tuple(a): + return a.op == "call_function" and a.target is tuple + + def _match_pattern(match_pattern: list[Callable]) -> Node | None: + """ + Traverse up the graph and match the args one by one. + If there is a match, return the last matched node, or None otherwise. + """ + a = arg + for i, match in enumerate(match_pattern): + if not match(a): + return None + # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],) + if i < len(match_pattern) - 1: + if match is match_tuple: + a = a.args[0][0] # type: ignore[assignment,index] + else: + a = a.args[0] # type: ignore[assignment] + # pyrefly: ignore [bad-return] + return a + + all_match_patterns = [ + [match_dq, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_getitem, match_lstm], + [match_dq, match_getitem, match_getitem, match_lstm], + [match_tuple, match_dq, match_getitem, match_lstm], + ] + + for p in all_match_patterns: + matched_node = _match_pattern(p) + if matched_node is not None: + return matched_node + return None + + +def _reroute_tuple_getitem_pattern(graph: Graph): + """ + Search for patterns where N consecutive `tuple` call_function nodes are followed by + N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes. + If we find this pattern, reroute the consumers of the last `getitem` to skip these + N `tuple` and `getitem` nodes. + + Before: + + a b c + | \\ / + \\ tuple + \\ / + tuple + | + getitem(1) + | + getitem(0) + | + d + + After: + + b + | + d + """ + + def find_patterns( + node: Node, + index_stack: list[int], + current_pattern: list[Node], + matched_patterns: list[list[Node]], + seen: set[tuple[Node, tuple[int, ...]]], + ): + """ + Traverse the graph recursively to match for the N-tuple - N-getitem patterns, + starting at the given node. + + We use a stack to keep track of the expected `getitem` indices, since these are + reversed from the `tuple` indices. In the above example, the stack after + (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first + and then by getitem(0). + + TODO: traverse upwards from the output and handle the case when tuple is not a + separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c))) + """ + if len(index_stack) == 0 and len(current_pattern) > 0: + matched_patterns.append(copy.copy(current_pattern)) + current_pattern.clear() + + # Avoid duplicating work + state = (node, tuple(index_stack)) + if state in seen: + return + seen.add(state) + + # Iterate through users of this node to find tuple/getitem nodes to match + for user in node.users: + if user.op == "call_function" and user.target is tuple: + for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type] + if user_arg == node: + index_stack.append(i) + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + elif user.op == "call_function" and user.target is operator.getitem: + if len(index_stack) > 0: + if user.args[1] == index_stack[-1]: + index_stack.pop() + current_pattern.append(user) + find_patterns( + user, index_stack, current_pattern, matched_patterns, seen + ) + return matched_patterns + + # Collect all matched patterns + matched_patterns: list[list[Node]] = [] + seen: set[tuple[Node, tuple[int, ...]]] = set() # (node, index_stack) + for node in graph.nodes: + find_patterns(node, [], [], matched_patterns, seen) + + # For each pattern, redirect all consumers of the last getitem node to the correct input + # of the first tuple node + for pattern in matched_patterns: + first_tuple = pattern[0] + last_getitem = pattern[-1] + if not (first_tuple.op == "call_function" and first_tuple.target is tuple): + raise AssertionError( + "first tuple node must be a call_function with target tuple" + ) + if not ( + last_getitem.op == "call_function" + and last_getitem.target is operator.getitem + ): + raise AssertionError( + "last getitem node must be a call_function with target operator.getitem" + ) + last_getitem_index = last_getitem.args[1] + new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index] + for user in list(last_getitem.users.keys()): + user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type] + + +def _get_observer_from_activation_post_process( + activation_post_process: ObserverBase | FakeQuantizeBase, +) -> ObserverBase: + """ + If `activation_post_process` is an observer, return the observer. + If `activation_post_process` is a fake quantize, return the internal observer. + """ + if isinstance(activation_post_process, ObserverBase): + return activation_post_process + else: + if not isinstance(activation_post_process, FakeQuantizeBase): + raise AssertionError( + "activation_post_process must be an ObserverBase or FakeQuantizeBase" + ) + return activation_post_process.activation_post_process # type: ignore[return-value] + + +def _qconfig_satisfies_dtype_config_constraints( + qconfig: QConfigAny, + dtype_with_constraints: DTypeWithConstraints, + is_activation: bool = True, +) -> bool: + """ + Return whether `qconfig` satisfies the following constraints from the backend, + specified through the activation and weight DTypeWithConstraints. + + 1. QConfig specified a quantization range that falls within the backend's, if any + 2. QConfig specified a min scale value that is >= the backend's, if any + 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has + scale and zero point that match the backend's, if any + + If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`. + If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True. + """ + + # TODO: log warnings only when the user enabled a debug flag + def _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process: ObserverBase | FakeQuantizeBase, + dtype_with_constraints: DTypeWithConstraints, + debug_string: str, + ) -> bool: + observer = _get_observer_from_activation_post_process(activation_post_process) + app_quant_min = getattr(observer, "quant_min", None) + app_quant_max = getattr(observer, "quant_max", None) + # TODO: for now, just use the existing eps value as scale_min. In the future, we should + # resolve the differences between the two, either by renaming eps or some other way + app_scale_min = getattr(observer, "eps", None) + backend_quant_min = dtype_with_constraints.quant_min_lower_bound + backend_quant_max = dtype_with_constraints.quant_max_upper_bound + backend_scale_min = dtype_with_constraints.scale_min_lower_bound + backend_scale_exact_match = dtype_with_constraints.scale_exact_match + backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match + # check quantization ranges + if backend_quant_min is not None and backend_quant_max is not None: + if app_quant_min is None or app_quant_max is None: + warnings.warn( + f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}", + stacklevel=2, + ) + return False + elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max: + warnings.warn( + f"QConfig {debug_string} quantization range must fall within the backend's:\n" + f"QConfig range = ({app_quant_min}, {app_quant_max}), " + f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), " + f"ignoring {qconfig}", + stacklevel=2, + ) + return False + # check scale min + if backend_scale_min is not None: + if app_scale_min is None: + warnings.warn( + f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}", + stacklevel=2, + ) + return False + if app_scale_min < backend_scale_min: + warnings.warn( + f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to " + f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}", + stacklevel=2, + ) + return False + # check fixed scale and zero point + if ( + backend_scale_exact_match is not None + and backend_zero_point_exact_match is not None + ): + # For tests only, accept the following qconfigs for now + # TODO: handle fp16 qconfigs properly + for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]: + if qconfig_equals(qconfig, accepted_qconfig): + return True + suggestion_str = ( + "Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n" + ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n' + " model = prepare_fx(model, qconfig_mapping, example_inputs)" + ) + if not isinstance( + activation_post_process, FixedQParamsObserver + ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize): + warnings.warn( + f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize " + f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}", + stacklevel=2, + ) + return False + if ( + observer.scale != backend_scale_exact_match + or observer.zero_point != backend_zero_point_exact_match + ): + warnings.warn( + f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) " + f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), " + f"ignoring {qconfig}.\n{suggestion_str}", + stacklevel=2, + ) + return False + return True + + if qconfig is None or dtype_with_constraints.dtype is None: + return True + + activation_post_process_ctr = ( + qconfig.activation if is_activation else qconfig.weight + ) + debug_string = "activation" if is_activation else "weight" + satisfies_constraints = True + if activation_post_process_ctr is not None: + activation_post_process = activation_post_process_ctr() + if not _is_activation_post_process(activation_post_process): + raise AssertionError( + "activation_post_process must be an activation post process" + ) + # If dtypes don't match, don't check the activation_post_process and return True early + if activation_post_process.dtype != dtype_with_constraints.dtype: + return True + satisfies_constraints = ( + _activation_post_process_satisfies_dtype_config_constraints( + activation_post_process, dtype_with_constraints, debug_string + ) + ) + return satisfies_constraints diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75f32eb8d801f271b51d12671bc2e4cf7e4eb5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py @@ -0,0 +1,891 @@ +# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +import logging +from abc import ABCMeta +from typing import Any + +import torch +from torch.ao.quantization.observer import ( + AffineQuantizedObserverBase, + get_block_size, + Granularity, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +logger = logging.getLogger(__name__) + +FP8_TYPES = { + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} + +""" +Map from dtype to the bound value of integers +TODO: maybe can replace this with call to torch.iinfo +""" +_DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype | TorchAODType, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), +} +_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + + +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + + +# TODO: decide on if we want to allow custom quant_min/quant_max here +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype in FP8_TYPES: + quant_min_lower_bound, quant_max_upper_bound = ( + torch.finfo(dtype).min, + torch.finfo(dtype).max, + ) + elif dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + else: + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + if quant_min < quant_min_lower_bound: + raise AssertionError( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + if quant_max > quant_max_upper_bound: + raise AssertionError( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + return quant_min, quant_max + + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + if len(block_size) != len(input_size): + raise AssertionError( + "block_size length must equal input_size length, got " + f"block_size={block_size}, input_size={input_size}" + ) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + if input_size[i] % block_size[i] != 0: + raise AssertionError( + f"Expecting input size at {i} dimension: {input_size[i]} to be divisible " + f"by block_size at {i} dimension: {block_size[i]}" + ) + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def _register_custom_op(lib): + """This decorator is used to preserve some high level operators for torch.export.export + while still allow them to be decomposed for inductor path + + requirement: make sure `fn.__name__[1:]` is the operator name you want to register + + NOTE: This should be applied at the top, after all other decorators have been applied + NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, + e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make + sense for downstream system (like executorch) to accept as well + + Example: + lib = torch.library.Library("my_namespace', "FRAGMENT") + + register_custom_op = _register_custom_op(lib) + + @register_custom_op + def _the_op_that_needs_to_be_preserved(...) + ... + + # after this, `_the_op_that_needs_to_be_preserved` will be preserved as + # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after + # torch.export.export / torch._export.export_for_training + + """ + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + from torch._library.infer_schema import infer_schema + + # expecting fn.__name__ starts with `_` and we want to take the rest + # to be the name of the custom op + if fn.__name__[0] != "_": + raise AssertionError( + f"Expecting function name starts with `_`, got {fn.__name__}" + ) + if any(c in fn.__name__ for c in ".<>"): + raise AssertionError( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, "CompositeImplicitAutograd") + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + register_decomposition([op])(fn) + return op + + return decorator + + +quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901 + +register_custom_op = _register_custom_op(quant_lib) + + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: tuple[int, ...], + target_dtype: torch.dtype, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, +) -> tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. + This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name if zero_point_domain is not None else None, + min_val, + max_val, + ) + + +@register_custom_op +def _choose_qparams_affine( + input: torch.Tensor | None, + mapping_type: str, + block_size: list[int], + target_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + eps: float | None = None, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: str | None = "INT", + min_val: torch.Tensor | None = None, + max_val: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """op definition that has compatible signatures with custom op library + + The op does the following: + 1. figure out the dimension for reduction based on block_size + 2. find min_val/max_val based on the dimension for reduction + 3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` + and `zero_point_domain` + """ + quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + if mapping_type not in [ + MappingType.SYMMETRIC.name, + MappingType.SYMMETRIC_NO_CLIPPING_ERR.name, + MappingType.ASYMMETRIC.name, + ]: + raise AssertionError(f"Unsupported mapping type: {mapping_type}") + if target_dtype in FP8_TYPES: + if mapping_type != MappingType.SYMMETRIC.name: + raise AssertionError( + f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + ) + + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps + + if len(block_size) != input.dim(): + raise AssertionError( + f"Got input dim:{input.dim()}, block_size: {block_size}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + if min_val is None or max_val is None: + raise AssertionError( + f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + ) + if min_val.dtype != max_val.dtype: + raise AssertionError( + f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + ) + + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps + + if preserve_zero: + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + else: + min_val_neg = min_val + max_val_pos = max_val + + if ( + mapping_type == MappingType.SYMMETRIC.name + or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + ): + # scales + if mapping_type == MappingType.SYMMETRIC.name: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + else: + if mapping_type != MappingType.SYMMETRIC_NO_CLIPPING_ERR.name: + raise AssertionError( + f"Expected mapping_type to be SYMMETRIC_NO_CLIPPING_ERR, got {mapping_type}" + ) + # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and + # quant_max = 7. + # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding + # error than the existing SYMMETRIC case. + # - If smax is bigger: it covers the positive values up to 7. The round + # error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after + # quantization. + smin = min_val_neg / float(quant_min) + smax = max_val_pos / float(quant_max) + mask = smin > smax + scale = torch.where(mask, smin, smax) + # zeros + if not preserve_zero: + raise ValueError( + "preserve_zero == False is not supported for symmetric quantization" + ) + if ( + zero_point_domain is not None + and zero_point_domain != ZeroPointDomain.INT.name + ): + raise ValueError( + "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" + ) + scale = torch.clamp(scale, min=eps) + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) + else: + if mapping_type != MappingType.ASYMMETRIC.name: + raise AssertionError( + f"Expected mapping_type to be ASYMMETRIC, got {mapping_type}" + ) + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + if preserve_zero: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError( + "if not preserve_zero, zero_point must be in FLOAT domain" + ) + mid_point = (quant_max + quant_min + 1) / 2 + zero_point = min_val_neg + scale * mid_point + + if zero_point is not None: + zero_point = zero_point.to(dtype=zero_point_dtype) + return scale.to(dtype=scale_dtype), zero_point + + +@torch.no_grad() +def quantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + output_dtype: torch.dtype, + quant_min: int | float | None = None, + quant_max: int | float | None = None, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): original float32, float16 or bfloat16 Tensor + block_size: (Tuple[int, ...]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Note: + How can block_size represent different granularities? + let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different + granularities: + + granularity type | block_size + per_tensor | (3, 3, 10, 10) + per_axis (axis=0) | (1, 3, 10, 10) + per_axis (axis=1) | (3, 1, 10, 10) + per_group (groupsize=2) | (3, 3, 10, 2) + per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + + Output: + quantized tensor with requested dtype + """ + return _quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + ) + + +@register_custom_op +def _quantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + output_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + zero_point_domain: str | None = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library + + Note: + zero_point_domain is optional specifies how we quantize the floating point to quantized data: + INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization + Where we do not want to round values to nearest integer and instead scale and cast. + """ + quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) + # workaround for uintx dtypes, since we don't have native Uintx dtype connected with + # torch.uintx dtypes yet + if output_dtype in _SUB_BYTE_UINT_BOUNDS: + output_dtype = torch.uint8 + return _quantize_affine_no_dtype_cast( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + ).to(output_dtype) + + +def _quantize_affine_no_dtype_cast( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + quant_min: int | float, + quant_max: int | float, + zero_point_domain: str | None = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """ + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to original shape + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise AssertionError(f"Unsupported input dtype: {input.dtype}") + if len(block_size) != input.dim(): + raise AssertionError(f"Got input dim:{input.dim()}, block_size: {block_size}") + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + quant = torch.clamp( + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max + ) + elif zero_point_domain == ZeroPointDomain.NONE.name: + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is NONE" + ) + quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) + elif zero_point_domain is None: + # This case handles quantization for float8 we expect no zero point and no zero point domain + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is None" + ) + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError(f"Unexpected zero_point_domain: {zero_point_domain}") + mid_point = (quant_max + quant_min + 1) / 2 + min_val = zero_point - scale * mid_point + quant = torch.clamp( + torch.round((input - min_val) / scale), quant_min, quant_max + ) + quant = quant.view(original_shape) + + return quant + + +def dequantize_affine( + input: torch.Tensor, + block_size: tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + input_dtype: torch.dtype, + quant_min: int | float | None = None, + quant_max: int | float | None = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument + block_size: (List[int]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (Tensor): quantization parameter for affine quantization + zero_point (Tensor): quantization parameter for affine quantization + input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for input Tensor + quant_max (Optional[int]): maximum quantized value for input Tensor + output_dtype (torch.dtype): dtype for output Tensor, default is fp32 + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Output: + dequantized Tensor, with requested dtype or fp32 + """ + return _dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + output_dtype=output_dtype, + ) + + +@register_custom_op +def _dequantize_affine( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + input_dtype: torch.dtype, + quant_min: int | float | bool | None = None, + quant_max: int | float | bool | None = None, + zero_point_domain: str | None = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library""" + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input_dtype not in _SUB_BYTE_UINT_BOUNDS: + if input.dtype != input_dtype: + raise AssertionError(f"Expected: {input_dtype}, got: {input.dtype}") + if output_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise AssertionError(f"Unsupported output dtype: {output_dtype}") + quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) + return _dequantize_affine_no_dtype_check( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + output_dtype, + ) + + +def _dequantize_affine_no_dtype_check( + input: torch.Tensor, + block_size: list[int], + scale: torch.Tensor, + zero_point: torch.Tensor | None, + quant_min: int | float, + quant_max: int | float, + zero_point_domain: str | None = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function converts AQT tensors to their high precision floating point representation + + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to original shape and change dtype to the output_dtype + """ + if len(block_size) != input.dim(): + raise AssertionError(f"Got input dim:{input.dim()}, block_size: {block_size}") + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + # Force a copy to avoid input modification due + # to upcoming in-place operations. + dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant = dequant - zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain == ZeroPointDomain.NONE.name: + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is NONE" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + if zero_point is not None: + raise AssertionError( + "zero_point should be None when zero_point_domain is None" + ) + if not _is_float8_type(input.dtype): + raise AssertionError( + f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + ) + dequant = input.to(output_dtype) + dequant = dequant * scale + else: + if zero_point_domain != ZeroPointDomain.FLOAT.name: + raise AssertionError(f"Unexpected zero point domain: {zero_point_domain}") + # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) + mid_point = (quant_max + quant_min + 1) / 2 + # This should allocate new memory and avoid input modification + dequant = input - mid_point + dequant = dequant.to(output_dtype) + dequant *= scale + if zero_point is not None: + dequant += zero_point + + return dequant.view(original_shape).to(output_dtype) + + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + if self.granularity is None: + raise AssertionError("granularity is None") + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + if self.min_val.shape != min_val.shape: + raise AssertionError( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + if self.max_val.shape != max_val.shape: + raise AssertionError( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, "min_val") and hasattr(self, "max_val")): + raise AssertionError( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + averaging_constant=0.01, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + is_dynamic=False, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + if self.granularity is None: + raise AssertionError("granularity is None") + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + if self.min_val.shape != min_val.shape: + raise AssertionError( + f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + ) + if self.max_val.shape != max_val.shape: + raise AssertionError( + f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + ) + min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) + max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, "min_val") and hasattr(self, "max_val")): + raise AssertionError( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: int | None = None, + quant_max: int | None = None, + eps: float | None = None, + is_dynamic=False, + scale_dtype: torch.dtype | None = None, + zero_point_dtype: torch.dtype | None = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain | None = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input): + self.block_size = get_block_size(input.shape, self.granularity) + self.original_dtype = input.dtype + return input + + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..6eaeaa46a924893fa0aace363f3040d5e2d692de --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -0,0 +1,341 @@ +import copy +import logging +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +import torch +from torch.ao.ns.fx.utils import compute_sqnr +from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node +from torch.nn import functional as F + + +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" + +log = logging.getLogger(__name__) + + +def generate_numeric_debug_handle(ep: ExportedProgram) -> None: + """ + Attach numeric_debug_handle_id for all nodes in the graph module of the given + ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder. + Notice that nodes like getattr are out of scope since they are not in the graph. + + The graph nodes of input exported program are modified inplace. + + Here's an example of using debug handle quantize flow:: + + ep = export_for_training(eager_model, example_inputs) + generate_numeric_debug_handle(ep) + + m = ep.module() + quantizer = XNNPACKQuantizer() + m = prepare_pt2e(m, quantizer) + m = convert_pt2e(m) + """ + + # Sanity check the input data type + if not isinstance(ep, ExportedProgram): + raise ValueError( + f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}" + ) + + unique_id = 0 + + def _find_max_id(node: torch.fx.Node) -> None: + nonlocal unique_id + unique_id = max( + unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0) + ) + + def _assign_debug_handle(node: torch.fx.Node) -> None: + nonlocal unique_id + if CUSTOM_KEY not in node.meta: + node.meta[CUSTOM_KEY] = {} + + if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]: + node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id + unique_id += 1 + + # Find the max ID that exists in the graph first, in case part of the graph + # has already been annotated. This way we guarantee there are no duplicate + # handle IDs. + bfs_trace_with_node_process(ep, _find_max_id) + + unique_id += 1 + + # Assign debug handles to all nodes in the graph that don't have one based on the + # max ID found in the previous step. + bfs_trace_with_node_process(ep, _assign_debug_handle) + + +def _detach(x: object) -> object: + detached: object = None + if isinstance(x, torch.Tensor): + detached = x.detach() + elif isinstance(x, (list, tuple)): + detached = type(x)([_detach(e) for e in x]) + elif isinstance(x, dict): + detached = {k: _detach(e) for k, e in x.items()} + else: + detached = x + return detached + + +def _tensor_shape_equals(x: object, y: object) -> bool: + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return x.shape == y.shape + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y)) + elif isinstance(x, dict) and isinstance(y, dict): + all_equal = True + for k in x: + all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k])) + return all_equal + else: + log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y) + return type(x) is type(y) and x == y + + +def _loss_fn( + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object +) -> object: + """The returned loss will have the same structure as `x` and `y`, e.g. + if both are Tensor, we'll return a Tensor + if both are list, we'll return a list of Tensors + if both are dict, we'll return a dict with the same key, and value being the loss between the + two Tensors + """ + if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): + return loss(x.to(torch.float32), y.to(torch.float32)) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)]) + elif isinstance(x, dict) and isinstance(y, dict): + return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()} + else: + return None + + +class OutputLogger(torch.nn.Module): + """ + Base class for capturing output values for nodes in a GraphModule, it only captures + Tensor output currently, but we can extend it to work for other types of inputs later if needed + """ + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + debug_handle: int, + node_name: str | None = None, + nn_module_stack: object | None = None, + ) -> None: + super().__init__() + self.node_name = node_name + self.nn_module_stack = nn_module_stack + self.debug_handle = debug_handle + self.stats: list[object] = [] + + def forward(self, x: object) -> object: + self.stats.append(_detach(x)) + return x + + def __extra_repr__(self) -> str: + return ( + f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" + ) + + +def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: + """For a given node, adds an OutputLogger that observes the output of that node, + and all its users use the OutputLogger output instead. + The OutputLogger will contain the debug_handle which can be used to compare + graphs after transforms""" + + # to avoid circular dep + from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix + + # add a logger after the node + with model.graph.inserting_after(node): + get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger") + logger_name = get_new_attr_name(model) + setattr( + model, + logger_name, + OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + ) + logger_node = model.graph.call_module(logger_name, (node,), {}) + + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is logger_node: + continue + user_node.replace_input_with(node, logger_node) + + return logger_node + + +def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: + """Add output loggers to node that has numeric_debug_handle + + Args: + model (GraphModule): original model + Returns: + a model with output loggers for all nodes that has numeric_debug_handle_id + """ + # don't change the original model + model = copy.deepcopy(model) + for n in model.graph.nodes: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): + continue + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] + _insert_logger(model, n, numeric_debug_handle) + + model.recompile() + return model + + +@dataclass(frozen=True) +class QuantizationComparisonResult: + actual: torch.Tensor + ref: torch.Tensor + + @property + def mse_loss(self) -> object: + return self.loss(F.mse_loss) + + @property + def sqnr(self) -> object: + return self.loss(compute_sqnr) + + def loss( + self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> object: + return _loss_fn(loss_function, self.actual, self.ref) + + def __repr__(self) -> str: + # Don't include the tensors themselves as they are quite large to print + # out. + return ( + f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})" + ) + + def __post_init__(self) -> None: + if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}" + ) + + if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}" + ) + + if not _tensor_shape_equals(self.ref, self.actual): + raise ValueError( + f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}" + ) + + +@dataclass(frozen=True) +class NodeAccuracySummary: + handle: int + actual_node_name: str + actual_module_stack: str + ref_node_name: str + ref_module_stack: str + results: Sequence[QuantizationComparisonResult] + + +def _module_stack_to_str(module_stack: object) -> str: + """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear") + to "mod.foo.0.linear" + """ + if not isinstance(module_stack, dict): + return str(module_stack) + module_values_list = list(module_stack.values()) + if len(module_values_list) > 0: + owning_module = module_values_list[-1][0] + return str(owning_module) + else: + return str(module_stack) + + +def extract_results_from_loggers( + model: GraphModule, +) -> dict[int, tuple[str | None, object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug handle. + The reason we have a list of object, instead of Tensor is because the output of node may not be + a Tensor, it could be (nested) list, tuple or dict as well. + + Returns: + A dict is keyed by the debug_handle id and the values are a list of object recorded + in loggers + + """ + # Results maps debug handle to a tensor list for each model being compared. + handles: dict[int, tuple[str | None, object, list[object]]] = {} + for _name, module in model.named_children(): + if isinstance(module, OutputLogger) and len(module.stats) > 0: + handles[module.debug_handle] = ( + module.node_name, + module.nn_module_stack, + module.stats, + ) + + return handles + + +def compare_results( + ref_results: dict[int, tuple[str | None, object, list[torch.Tensor]]], + actual_results: dict[int, tuple[str | None, object, list[torch.Tensor]]], +) -> dict[int, NodeAccuracySummary]: + """Given two dict mapping from `debug_handle_id` (int) to list of tensors + return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + comparison information like SQNR, MSE etc. + + Args: + ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id + actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + + Returns: + Dict[int, NodeAccuracySummary] + """ + comparisons = {} + for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_handle not in actual_results: + log.debug( + "Cannot compare for handle %s because it wasn't found in the transformed model", + debug_handle, + ) + continue + actual_name, actual_stack, actual_stats = actual_results[debug_handle] + try: + results = [ + QuantizationComparisonResult(actual=a, ref=b) + for a, b in zip(actual_stats, ref_stats) + ] + except Exception as e: + # Add extra information for an exception from QuantizationComparisonResult + # if the shapes didn't match, to include the handle and the node names. + raise ValueError( + f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + ) from e + + comparisons[debug_handle] = NodeAccuracySummary( + handle=debug_handle, + actual_node_name=actual_name or "", + actual_module_stack=_module_stack_to_str(actual_stack), + ref_node_name=ref_name or "", + ref_module_stack=_module_stack_to_str(ref_stack), + results=results, + ) + + return comparisons diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..81c03e51414320e3faee0f6b8906ea38910c41ee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -0,0 +1,81 @@ +import logging +import operator + +import torch +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _is_valid_annotation, +) +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["DuplicateDQPass"] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _maybe_duplicate_dq( + gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node +) -> None: + annotation = user.meta.get("quantization_annotation", None) + if not _is_valid_annotation(annotation): # type: ignore[arg-type] + return + with gm.graph.inserting_after(dq_node): + new_node = gm.graph.node_copy(dq_node) + + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == dq_node: + return new_node + else: + return n + + new_args = map_arg(user.args, maybe_replace_node) + new_kwargs = map_arg(user.kwargs, maybe_replace_node) + user.args = new_args # type: ignore[assignment] + user.kwargs = new_kwargs # type: ignore[assignment] + + +class DuplicateDQPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: + dq_users = _filter_sym_size_users(node) + if len(dq_users) <= 1: + continue + # Do not duplicate dq for dynamic quantization + # Pattern: choose_qparam - getitem - q - dq + q_node = node.args[0] + if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: + getitem_node = q_node.args[1] + if ( + isinstance(getitem_node, torch.fx.node.Node) + and getitem_node.op == "call_function" + and getitem_node.target is operator.getitem + ): + choose_qparam_node = getitem_node.args[0] + if ( + isinstance(choose_qparam_node, torch.fx.node.Node) + and choose_qparam_node.op == "call_function" + and choose_qparam_node.target + == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + continue + for user in dq_users: + _maybe_duplicate_dq(graph_module, node, user) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70cca73dd00dcb4bd865dda4f2718a610496323e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import types + +import torch +import torch.nn.functional as F +from torch.ao.quantization.utils import _assert_and_get_unique_device + + +__all__ = [ + "model_is_exported", +] + +_EXPORTED_TRAINING_ATTR = "_exported_training" + + +class _WrapperModule(torch.nn.Module): + """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you + are trying to export a callable. + """ + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`.""" + return self.fn(*args, **kwargs) + + +def model_is_exported(m: torch.nn.Module) -> bool: + """ + Return True if the `torch.nn.Module` was exported, False otherwise + (e.g. if the model was FX symbolically traced or not traced at all). + """ + return isinstance(m, torch.fx.GraphModule) and any( + "val" in n.meta for n in m.graph.nodes + ) + + +def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch dropout patterns in the model between train and eval modes. + + Dropout has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the dropout behavior between the two modes, so here we need to rewrite the aten + dropout patterns manually to achieve the same effect. + + See https://github.com/pytorch/pytorch/issues/103681. + """ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + for inplace in [False, True]: + + def dropout_train(x): + return F.dropout(x, p=0.5, training=True, inplace=inplace) + + def dropout_eval(x): + return F.dropout(x, p=0.5, training=False, inplace=inplace) + + example_inputs = (torch.randn(1),) + if train_to_eval: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + else: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch batchnorm patterns in the model between train and eval modes. + + Batchnorm has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the batchnorm behavior between the two modes, so here we need to rewrite the aten + batchnorm patterns manually to achieve the same effect. + """ + # TODO(Leslie): This function still fails to support custom momentum and eps value. + # Enable this support in future updates. + + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + def bn_train( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + + def bn_eval( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False + ) + + example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + device = _assert_and_get_unique_device(m) + is_cuda = device is not None and device.type == "cuda" + bn_train_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_train), + example_inputs, + is_cuda, + ) + bn_eval_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_eval), + example_inputs, + is_cuda, + ) + + if train_to_eval: + match_pattern = bn_train_aten + replacement_pattern = bn_eval_aten + else: + match_pattern = bn_eval_aten + replacement_pattern = bn_train_aten + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +# TODO: expose these under this namespace? +def _move_exported_model_to_eval(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to eval mode. + + This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing inference on the model. + + This call is idempotent; if the model is already in eval mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True) + if not is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, False) + _replace_dropout(model, train_to_eval=True) + _replace_batchnorm(model, train_to_eval=True) + return model + + +def _move_exported_model_to_train(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to train mode. + + This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing training on the model. + + This call is idempotent; if the model is already in train mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False) + if is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, True) + _replace_dropout(model, train_to_eval=False) + _replace_batchnorm(model, train_to_eval=False) + return model + + +def _allow_exported_model_train_eval(model: torch.fx.GraphModule): + """ + Allow users to call `model.train()` and `model.eval()` on an exported model, + but with the effect of changing behavior between the two modes limited to special + ops only, which are currently dropout and batchnorm. + + Note: This does not achieve the same effect as what `model.train()` and `model.eval()` + does in eager models, but only provides an approximation. In particular, user code + branching on `training` flag will not function correctly in general because the branch + is already specialized at export time. Additionally, other ops beyond dropout and batchnorm + that have different train/eval behavior will also not be converted properly. + """ + + def _train(self, mode: bool = True): + if mode: + _move_exported_model_to_train(self) + else: + _move_exported_model_to_eval(self) + + def _eval(self): + _move_exported_model_to_eval(self) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b46011d8c41d3dc0d980e33caecb22b615fd22 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections import OrderedDict +from collections.abc import Callable, Sequence +from typing import Any + +import torch +from torch.export import ExportedProgram +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + check_subgraphs_connected, + get_source_partitions, + SourcePartition, +) + + +__all__ = [ + "find_sequential_partitions", + "get_equivalent_types", + "update_equivalent_types_dict", + "bfs_trace_with_node_process", +] + +_EQUIVALENT_TYPES: list[set] = [ + {torch.nn.Conv1d, torch.nn.functional.conv1d}, + {torch.nn.Conv2d, torch.nn.functional.conv2d}, + {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d}, + {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_}, + {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm}, + {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_}, + {torch.add, operator.add, operator.iadd, "add", "add_"}, + {torch.mul, operator.mul, operator.imul, "mul", "mul_"}, +] + + +def _create_equivalent_types_dict(): + _DICT = {} + for values in _EQUIVALENT_TYPES: + for v in values: + _DICT[v] = list(values) + return _DICT + + +_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def get_equivalent_types() -> list[set]: + return _EQUIVALENT_TYPES + + +def update_equivalent_types_dict(customized_equivalent_types=None): + """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + When customized_equivalent_types passes in, + re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + """ + if customized_equivalent_types is None: + raise ValueError("customized_equivalent_types should not be None") + global _EQUIVALENT_TYPES + global _EQUIVALENT_TYPES_DICT + _EQUIVALENT_TYPES = customized_equivalent_types + _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def _partitions_sequential(partitions: Sequence[SourcePartition]): + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def _get_matching_types(partition_type): + matching_types = [partition_type] + if partition_type in _EQUIVALENT_TYPES_DICT: + matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type]) + return matching_types + + +def _valid_type_sequence(partition_types: list[Any]): + partition_types_set = set() # type: ignore[var-annotated] + for partition_type in partition_types: + matching_types = _get_matching_types(partition_type) + matching_types_set = set(matching_types) + if len(partition_types_set & matching_types_set) > 0: + return False + partition_types_set |= matching_types_set + return True + + +def find_sequential_partitions( + gm: torch.fx.GraphModule, + partition_types: list[Any], + include_functional_equivalent=True, + filter_fn: Callable[[Node], bool] | None = None, +): + if not _valid_type_sequence(partition_types): + raise ValueError( + f"Invalid partition types: {partition_types}. Each type in the sequence must be unique" + ) + + typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + types_to_match = _get_matching_types(partition_type) + partitions = get_source_partitions(gm.graph, types_to_match, filter_fn) + typed_partitions[partition_type] = list( + itertools.chain.from_iterable(partitions.values()) + ) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [ + candidate + for candidate in fusion_candidates + if _partitions_sequential(candidate) + ] + return fused_partitions + + +def _get_submodule( + graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int +) -> tuple[str, torch.nn.Module, torch.fx.Node]: + submod_node = node.args[arg_index] + if not isinstance(submod_node, torch.fx.Node): + raise AssertionError( + f"Expected submod_node to be a torch.fx.Node, got {type(submod_node)}" + ) + if submod_node.op != "get_attr": + raise AssertionError( + f"Expected submod_node.op to be 'get_attr', got {submod_node.op}" + ) + if not isinstance(submod_node.target, str): + raise AssertionError( + f"Expected submod_node.target to be a string attribute name, got {type(submod_node.target)}" + ) + submodule = graph_module.get_submodule(submod_node.target) + # pyre-ignore + return submod_node.target, submodule, node + + +def _get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing a + tuple of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + control_flow_submodules = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target is torch.ops.higher_order.cond: + control_flow_submodules.append(_get_submodule(graph_module, node, 1)) + control_flow_submodules.append(_get_submodule(graph_module, node, 2)) + if node.target is torch.ops.higher_order.map_impl: + control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + + return control_flow_submodules + + +def bfs_trace_with_node_process( + model: ExportedProgram | torch.fx.GraphModule, node_op: Callable +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + if not isinstance(model, (ExportedProgram, torch.fx.GraphModule)): + raise AssertionError( + f"Expected GraphModule or ExportedProgram, got {type(model)}" + ) + gm = model.graph_module if isinstance(model, ExportedProgram) else model + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if node.op in ["output", "placeholder"]: + continue + + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in _get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..c306b1745badaf575060a6a2fb4ed21f6977ab75 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py @@ -0,0 +1,60 @@ +import torch +from torch._inductor.constant_folding import constant_fold +from torch._inductor.fx_passes.freezing_patterns import freezing_passes + + +__all__ = [ + "lower_pt2e_quantized_to_x86", +] + + +def lower_pt2e_quantized_to_x86( + model: torch.fx.GraphModule, + example_inputs: tuple[torch.Tensor, ...], +) -> torch.fx.GraphModule: + """Lower a PT2E-quantized model to x86 backend. + + Args: + * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. + * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + + Return: + A GraphModule lowered to x86 backend. + """ + + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] + decomp_table = torch.export.default_decompositions() + + # if we are post-autograd, we shouldn't + # decomp prim ops. + for k in list(decomp_table.keys()): + if not torch._export.utils._is_cia_op(k): + del decomp_table[k] + + return decomp_table + + def _node_replace(m): # type: ignore[no-untyped-def] + # Replace aten.t(x) with aten.permute(x, [1, 0]) + aten = torch.ops.aten + g = m.graph + for node in g.nodes: + if node.target is aten.t.default: + with g.inserting_before(node): + x = node.args[0] + dims = [1, 0] + perm_node = g.call_function(aten.permute.default, args=(x, dims)) + node.replace_all_uses_with(perm_node) + g.erase_node(node) + + g.lint() + m.recompile() + + lowered_model = ( + torch.export.export(model, example_inputs, strict=True) + .run_decompositions(_post_autograd_decomp_table()) + .module() + ) + _node_replace(lowered_model) + freezing_passes(lowered_model, example_inputs) + constant_fold(lowered_model) + return lowered_model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..be5878042b046447e446c3f4ee1cb1d761f29f27 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -0,0 +1,217 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._export.error import InternalError +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _find_q_dq_node_for_user, + _is_valid_annotation, +) +from torch.ao.quantization.quantizer import QuantizationSpecBase +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + +__all__ = ["PortNodeMetaForQDQ"] + +_METADATA_TO_PORT = [ + "stack_trace", + "quantization_tag", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.dequantize_affine, +] + +_CHOOSE_QPARAMS_OPS = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.pt2e_quant.choose_qparams_affine, +] + + +def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: + from_meta = from_node.meta + for meta_name in _METADATA_TO_PORT: + if meta_name in from_meta: + to_node.meta[meta_name] = from_meta[meta_name] + + +def _has_quant_annotation(node: torch.fx.Node) -> bool: + return "quantization_annotation" in node.meta + + +def _find_choose_qparams_node(node: torch.fx.Node) -> torch.fx.Node | None: + # BFS to look for choose qparams + from collections import deque + + queue = deque(list(node.users.keys())) + while len(queue): + n = queue.popleft() + if n.op == "output": + continue + if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: + return n + for k in n.users: + queue.append(k) + return None + + +def _port_metadata_for_input_quant_nodes( + input_node: torch.fx.Node, + node: torch.fx.Node, + qspec: QuantizationSpecBase | None, +): + if qspec is None: + return + + is_dynamic_quant = getattr(qspec, "is_dynamic", None) + if is_dynamic_quant is not None and is_dynamic_quant is True: + choose_qparams_node = _find_choose_qparams_node(input_node) + if choose_qparams_node is None: + raise ValueError(f"No chose qparams node found for {node}") + choose_qparam_users = _filter_sym_size_users(choose_qparams_node) + if len(choose_qparam_users) != 2: + raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") + scale_node = choose_qparam_users.pop() + dynamic_q_node = next(iter(scale_node.users.keys())) + dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) + if len(dynamic_q_node_users) > 1: + raise InternalError(f"Expecting single user for {dynamic_q_node}") + dynamic_dq_node = dynamic_q_node_users.pop() + _add_metadata(choose_qparams_node, node) + _add_metadata(dynamic_q_node, node) + _add_metadata(dynamic_dq_node, node) + else: + q_node, dq_node = _find_q_dq_node_for_user(input_node, node) + if q_node is None or dq_node is None: + return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while ( + isinstance(q_node_input, torch.fx.Node) + and q_node_input.op == "call_function" + and q_node_input.target + in [ + torch.ops.aten.flatten.using_ints, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten._mkldnn_transpose, + ] + ): + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) + _add_metadata(dq_node, node) + + +def _port_metadata_for_output_quant_nodes( + node: torch.fx.Node, qspec: QuantizationSpecBase | None +): + if qspec is None: + return + + node_users = _filter_sym_size_users(node) + if len(node.users) == 0: + return + if len(node_users) != 1: + logger.warning(f"Expecting {node} to have single user") # noqa: G004 + q_node = node_users.pop() + if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: + logger.warning( + f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 + ) # noqa: G004 + return + + _add_metadata(q_node, node) + + +class PortNodeMetaForQDQ(PassBase): + """ + Port metadata for nodes added by quantization flow. + For static quant these are: + - quantizer_per_tensor.default, dequantize_per_tensor.default + - quantizer_per_channel.default, dequantize_per_channel.default + For dynamic quant these are: + - choose_qparams.tensor + - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor + - quantizer_per_channel.default, dequantize_per_channel.default + + Rules of porting metadata: + - Metadata to be ported: + - nn_module_stack + - stack_trace + - quantization_tag + - Metadata to NOT be ported: + - Everything else + - Rules: + - Statically quantized patterns: + - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node. + - Quantize nodes on the outputs inherit metadata of the producer node. + - Example 1: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metadata from + - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] + - Note first Q and last DQ do not inherit metadata from any nodes + - Example 2: + - Original: [Conv -> AvgPool -> Linear] + - AvgPool is not quantized + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] + - Inner brackets specify which nodes Q/DQ inherit metadata from + - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] + - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because + AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation + on the nodes (in this case AvgPool node) to conclude if the node or pattern was + supposed to be quantized. And subsequently decide if the preceding Q, if any, should + inherit metadata from AvgPool. + - Dynamically quantized patterns: + - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes + - For example, below linear is dynamically quantized while rest statically: + - Original: [Conv -> AvgPool -> Linear] + - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear] + - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]] + - Note first Q does not inherit metadata from any nodes + NB: + - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely + knows which quantization spec is converted to q/dq and thus from where the metadata should be ported. + However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit. + Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant + code, this pass should like to be integrated in the refactored variant of "convert" step. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + annotation = node.meta.get("quantization_annotation", None) + if _is_valid_annotation(annotation): + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + output_qspec = node.meta["quantization_annotation"].output_qspec + for input_node, qspec in input_qspec_map.items(): + _port_metadata_for_input_quant_nodes(input_node, node, qspec) + _port_metadata_for_output_quant_nodes(node, output_qspec) + return PassResult(graph_module, True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3c8b4b33d881feda0864cc65698972a3226c7c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py @@ -0,0 +1,610 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import ( + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + ObserverOrFakeQuantize, + QConfigMapping, +) +from torch.ao.quantization.fx.custom_config import PrepareCustomConfig +from torch.ao.quantization.fx.prepare import ( + _create_obs_or_fq_from_qspec, + _insert_obs_or_fq, + _is_activation_post_process_node, + _save_state, +) +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.quantizer import ( + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.ao.quantization.utils import _assert_and_get_unique_device +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Argument + + +# TODO: make pt2e folder private? +__all__ = [ + "prepare", +] + + +def _find_root_edge_or_node( + edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode] +) -> EdgeOrNode: + """Find the root node for the sharing tree + Args: + edge_or_node: edge/node that we want to find the root + shared_with_map: each edge/node points to the parent, the root node will points to itself + + Returns: + root edge/node + """ + parent = shared_with_map[edge_or_node] + if parent == edge_or_node: + return edge_or_node + root = _find_root_edge_or_node(parent, shared_with_map) + # path compression + shared_with_map[edge_or_node] = root + return root + + +def _union( + parent: EdgeOrNode, + child: EdgeOrNode, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> None: + """Merge the subtree for `child` with `parent`, the order is important here""" + root_parent = _find_root_edge_or_node(parent, shared_with_map) + root_child = _find_root_edge_or_node(child, shared_with_map) + # union the two trees by pointing the root of child to root of parent + shared_with_map[root_child] = root_parent + + +def _update_shared_with( + child: EdgeOrNode, + qspec: QuantizationSpecBase, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +): + """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` + configuration and established the relationship between `edge_or_node` with the edge/node that it + is pointing to, we'll use this information in the end to get the group id + """ + if isinstance(qspec, SharedQuantizationSpec): + parent = qspec.edge_or_node + # we point from edge_or_node to the node that it is sharing_with, e.g. + # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` + _union(parent, child, shared_with_map) + + +def _unwrap_shared_qspec( + qspec: QuantizationSpecBase, + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> QuantizationSpecBase: + """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) + if qspec is SharedQuantizationSpec + (1). tries to find the root edge or node for the node that the qspec points to + (2). recursively find the root qspec based on the qspec for the root node + """ + if isinstance(qspec, SharedQuantizationSpec): + sharing_with = qspec.edge_or_node + root = _find_root_edge_or_node(sharing_with, shared_with_map) + qspec = edge_or_node_to_qspec[root] + return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + return qspec + + +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): + return ( + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) + + +def _get_edge_or_node_to_qspec( + model: torch.fx.GraphModule, +) -> dict[EdgeOrNode, QuantizationSpecBase]: + """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {} + for n in model.graph.nodes: + if hasattr(n, "meta") and "quantization_annotation" in n.meta: + qa = n.meta["quantization_annotation"] + for input_to_n, qspec in qa.input_qspec_map.items(): + input_edge = (input_to_n, n) + edge_or_node_to_qspec[input_edge] = qspec + if qa.output_qspec is not None: + output_node = n + qspec = qa.output_qspec + edge_or_node_to_qspec[output_node] = qspec + return edge_or_node_to_qspec + + +def _union_input_edge_with( + input_edge, + input_edge_root_qspec, + edge_or_node, + edge_or_node_to_qspec, + shared_with_map, +): + """Union input edge with another edge or node, used in implicit sharing to point the current input + edge to other user edges of the producer node, or the output of producer node since these are + referring to the same Tensor + """ + root_qspec = None + if edge_or_node in edge_or_node_to_qspec: + qspec = edge_or_node_to_qspec[edge_or_node] + root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + # TODO: add assertions for types of root qspecs + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] + ): + # the input arg to the node should reuse the existing output observer for arg + # since dtype is the same (we may want to extend this to be a more strict check + # in the future) + # so we point from `input_edge` to `arg` (output of the argument) + _union(edge_or_node, input_edge, shared_with_map) + + +def _get_edge_or_node_to_group_id( + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], +) -> dict[EdgeOrNode, int]: + """Map from edge/node to the group ID, generated from quantization annotations, + edge/node with the same group ID should use the same observer/fake_quant instance + + This is applying SharedQuantizationSpec configuration and map each edge/node to a group + There is another implicit sharing that's built in the quantization, when we have the following: + * op1 -> op2 + * output of op1: int8_qspec + * (op1 -> op2) input edge: int8_qspec + we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. + + Figuring out the correct group ID for all edge/node is a standard union find problem: + https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ + + Args: + edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations + Returns: + edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that + belongs to the same group should have the same id + + Example: + op2 -> cat1 -> cat2 + op1 / / + op3 + edge_or_node_to_qspec: { + op1: int8_qspec, + op2: int8_qspec, + (op1, cat1): int8_qspc, + (op2, cat1): SharedQuantizationSpec((op1, cat1)), + cat1: SharedQuantizationSpec((op1, cat1)), + (op3, cat2): int8_qspec, + (cat1, cat2): SharedQuantizationSpec((op3, cat2)), + cat2: SharedQuantizationSpec((op3, cat2)), + } + + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + edge_or_node_to_group_id: { + op1: 1, + op2: 1, + (op1, cat1): 1, + (op2, cat1): 1, + cat1: 1, + (op3, cat2): 1, + (cat1, cat2): 1, + cat2: 1, + } + # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which + # connects the two sharing group around cat1 and cat2 op due to transitive sharing + """ + # means the observer of key should be shared with observer with value, by default it will + # be shared with itself + shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { + k: k for k in edge_or_node_to_qspec + } + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + if isinstance(edge_or_node, torch.fx.Node): + output_node = edge_or_node + _update_shared_with(output_node, qspec, shared_with_map) + else: + input_edge = edge_or_node + input_edge_root_qspec = _unwrap_shared_qspec( + qspec, edge_or_node_to_qspec, shared_with_map + ) + + if not isinstance(input_edge, tuple): + raise AssertionError( + f"input_edge must be a tuple (arg, user), got {type(input_edge)}" + ) + arg, n = input_edge + if n.meta["quantization_annotation"].allow_implicit_sharing: + # NOTE: the order is important here, we first share with other users and then share with previous + # output because the reverse order could cause circular dependency + # e.g node1 -> node2 + # \ -> node3 + # when processing (node1, node2), if we first point (node1, node2) to node1 + # Step 1. shared_map = {(node1, node2): node1} + # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) , + # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} + # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) + # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll + # have a circular dependency + # the following order works around this issue, but this does not allow arbitrary configuration + # of sharing so it might break in a different case in the future, when it breaks + # quantizer writer can check the notes here to debug the issue + + # sharing with other users of the producer node + # (arg, user) + if not isinstance(arg, Node) or not isinstance(n, Node): + raise Exception( # noqa: TRY002 + f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}" + ) + for user in arg.users: + if user is n: + continue + arg_to_user_edge = (arg, user) + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg_to_user_edge, + edge_or_node_to_qspec, + shared_with_map, + ) + + # sharing with output of producer node + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg, + edge_or_node_to_qspec, + shared_with_map, + ) + + _update_shared_with(input_edge, qspec, shared_with_map) + + # now that we get the sharing relations between all edges and nodes, we can assign group ids + cur_group_id = 0 + edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} + for edge_or_node in shared_with_map: + root = _find_root_edge_or_node(edge_or_node, shared_with_map) + if root not in edge_or_node_to_group_id: + edge_or_node_to_group_id[root] = cur_group_id + cur_group_id += 1 + edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] + + return edge_or_node_to_group_id + + +def _get_obs_or_fq_map( + edge_or_node_to_group_id: dict[EdgeOrNode, int], + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + is_qat: bool, +) -> dict[EdgeOrNode, ObserverOrFakeQuantize]: + """Generates the EdgeOrNode to observer/fake_quant instances + Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant + instances + """ + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + group_id = edge_or_node_to_group_id[edge_or_node] + if group_id not in group_id_to_obs_or_fq: + # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify + # the implementation for _create_obs_or_fq_from_qspec + group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec( + qspec, obs_or_fq_map, is_qat + ) + obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] + return obs_or_fq_map + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Node | Any, + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + if not isinstance(arg, Node): + raise AssertionError( + f"expect original argument to be a Node, but got: {type(arg)}" + ) + # default (no observer) + new_arg = arg + + # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes + original_arg = arg + while _is_activation_post_process_node(original_arg, named_modules): + original_arg = original_arg.args[0] # type: ignore[assignment] + if not isinstance(original_arg, Node): + raise AssertionError( + f"expect original argument to be a Node, but got: {type(original_arg)}" + ) + + input_edge = (original_arg, node) + if input_edge not in obs_or_fq_map: + return new_arg + # input_edge needs to be observed + input_edge_obs_or_fq = obs_or_fq_map[input_edge] + if input_edge_obs_or_fq is None: + return new_arg + + arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg) + # the arg is observed as the output and is using the same instance as the input_edge + # we'll reuse the inserted observer/fake_quant + if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id( + input_edge_obs_or_fq + ): + return new_arg + + # otherwise, we'll insert a new observer/fake_quant node + + # skip inserting new observers if the same observer instance is inserted before for another user + # Example: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + # + # instead of inserting new observers we will have: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + for maybe_obs_node in arg.users: + if not _is_activation_post_process_node(maybe_obs_node, named_modules): + continue + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if id(maybe_obs_mod) == id(input_edge_obs_or_fq): + return maybe_obs_node + + if not isinstance(model.graph, Graph): + raise AssertionError( + f"Expected model.graph to be a torch.fx.Graph, got {type(model.graph)}" + ) + new_arg = _insert_obs_or_fq( + arg, + input_edge_obs_or_fq, + model, + named_modules, + model.graph, + model_device, + ) + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + new_args.append(new_arg) + + # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and + # gelu has a has an approximate kwarg that persist in exported graph. + # This is just a work around for these. + if not ( + node.target is torch.ops.aten.clone.default + or node.target is torch.ops.aten.zeros_like.default + or node.target is torch.ops.aten.gelu.default + or len(node.kwargs) == 0 + ): + raise AssertionError(" expecting kwargs for aten op IR to be empty") + + # assign the new args to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +) -> Node | None: + if node in obs_or_fq_map: + output_act_obs_or_fq = obs_or_fq_map[node] + new_output = _insert_obs_or_fq( + node, + output_act_obs_or_fq, + model, + named_modules, + graph, + model_device, + ) + # propagate numeric debug handle from original node to observer/fake_quant node + if ( + isinstance(node, Node) + and isinstance(new_output, Node) + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + return new_output + return None + + +def _maybe_insert_input_and_output_observers_for_node( + node: Node, + model: torch.fx.GraphModule, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + model_device: torch.device | None = None, +): + this_node_quantization_annotation = node.meta.get("quantization_annotation", None) + if this_node_quantization_annotation is None: + return + + named_modules = dict(model.named_modules(remove_duplicate=False)) + _maybe_insert_input_observers_for_node( + node, + None, # qconfig + model, + named_modules, + obs_or_fq_map, + is_qat, + model_device, + ) + + output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) + if not output_is_a_tensor: + return + + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + model_device, + ) + + if maybe_output_obs_node is None: + return + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + + +def prepare( + model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + is_qat: bool, + obs_or_fq_callback=None, +) -> GraphModule: + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance + # all edge/nodes that belongs to the same group will use the same instance + # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant + # instance + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map( + edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat + ) + if obs_or_fq_callback: + obs_or_fq_callback(model, obs_or_fq_map) + model_device = _assert_and_get_unique_device(model) + + for node in nodes_before_observation: + # TODO: simplify logic for inserting observers + _maybe_insert_input_and_output_observers_for_node( + node, + model, + obs_or_fq_map, + is_qat, + model_device, + ) + + model = GraphModule(model, model.graph) + + _save_state( + model, + {}, # node_name_to_qconfig + node_name_to_scope, + PrepareCustomConfig(), + {}, # equalization_node_name_to_qconfig + QConfigMapping(), + is_qat, + set(), # observed_node_names + ) + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9498a4f16f78f256baba85246dabeb458c9764c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py @@ -0,0 +1,1058 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import operator +from collections.abc import Callable +from typing import Any, TYPE_CHECKING + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) +from torch.fx import Graph, GraphModule, Node +from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns + +from .utils import ( + _get_aten_graph_module_for_pattern, + _is_bn_node, + _is_conv_or_conv_transpose_node, + _is_conv_transpose_fn, + fold_bn_weights_into_conv_node, +) + + +if TYPE_CHECKING: + from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [] # type: ignore[var-annotated] + + +def _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + is_cuda: bool, +) -> dict[str, Any]: + """ + Optional example inputs for quantized and folded conv-bn patterns + used in convert, expressed as kwargs. + """ + kwargs = {} + # Per tensor quantization uses literals to represent scale and zero + # point, so there is no need to include them here as kwargs + if is_per_channel: + kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias and bias_is_quantized: + kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias: + kwargs["conv_bias"] = torch.randn(1) + if is_cuda: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.cuda() + return kwargs + + +def _get_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + x = conv_fn(x, conv_weight, conv_bias) + x = F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + return x + + return _WrapperModule(_conv_bn_pattern) + + +# TODO: merge this with the `no_conv_bias` case +def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std. + This is based on `nniqat.ConvBn2d._forward_approximate`. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype) + x = conv_fn(x, scaled_weight, zero_bias) + x = x / scale_factor.reshape(bias_shape) + x = x + conv_bias.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern) + + +def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern_no_conv_bias( + x: torch.Tensor, + conv_weight: torch.Tensor, + # Not used, only for matching convenience + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias) + + +def _append_qdq(x, is_per_channel, is_bias, kwargs): + """ + Helper function to append q-dq ops after `x`, using dummy values for the qparams + and qmin/qmax. We use dummy values here because we match with `ignore_literals=True` + and will manually replace these values after subgraph rewriting. + + Return the dq node. + """ + # Dummy args to be passed into q-dq ops + per_channel_axis = 0 + scale_key = "bias_scale" if is_bias else "weight_scale" + zp_key = "bias_zero_point" if is_bias else "weight_zero_point" + scale = kwargs[scale_key] if is_per_channel else 1.0 + zp = kwargs[zp_key] if is_per_channel else 0 + qmin = -127 + qmax = 127 + dtype = torch.int8 + + qd = torch.ops.quantized_decomposed + if is_per_channel: + x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + else: + x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + return x + + +def _get_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Return the quantized version of QAT conv + BN pattern. + This is based on `nniqat.ConvBn2d._forward_approximate`, + used in QAT convert. We first match this pattern and replace + it with the normal [conv - bn] pattern, then fold the BN + weights into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + scaled_weight = _append_qdq( + scaled_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype) + if bias_is_quantized: + zero_bias = _append_qdq( + zero_bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + x = conv_fn(x, scaled_weight, zero_bias) + else: + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + if has_bias: + x = x + kwargs["conv_bias"].reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_quantized_qat_conv_bn_pattern) + + +def _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Quantized QAT conv - bn pattern with bn weights being folded into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _folded_quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + conv_weight = _append_qdq( + conv_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + bias = kwargs["conv_bias"] + if bias_is_quantized: + bias = _append_qdq( + bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + else: + bias = None + x = conv_fn(x, conv_weight, bias) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_folded_quantized_qat_conv_bn_pattern) + + +def _has_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph has bias. + """ + for n in match.nodes_map.values(): + if _is_conv_or_conv_transpose_node(n): + return len(n.args) > 2 and n.args[2] is not None + raise ValueError("Could not find conv node in matched conv + bn pattern") + + +def _no_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph does NOT have bias. + """ + return not _has_conv_bias_filter(match, original_graph, pattern_graph) + + +def _is_quantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def _is_dequantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]: + """ + Helper function to extract the nodes in the conv-bn fusion pattern after + subgraph rewriting, in the form of a map: + + {name: (original_node, replacement_node)} + + The following names must exist in the map: + + "conv", "conv_weight", "conv_input", "bn", "getitem" + + The following names may exist in the map: + + "conv_weight_q", "conv_weight_dq", "conv_bias", + "conv_bias_q", "conv_bias_dq" + """ + + def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Node | None]: + """ + Return a 3-tuple of (conv_node, bn_node, getitem_node). + This asserts that the match contains exactly one of each node. + """ + conv_node, bn_node, getitem_node = None, None, None + for n in nodes: + if n.op != "call_function": + continue + if _is_conv_or_conv_transpose_node(n): + if conv_node is not None: + raise AssertionError( + f"Found multiple conv nodes in match, previous: {conv_node}, new: {n}" + ) + conv_node = n + if _is_bn_node(n): + if bn_node is not None: + raise AssertionError( + f"Found multiple bn nodes in match, previous: {bn_node}, new: {n}" + ) + bn_node = n + if n.target is operator.getitem: + if getitem_node is not None: + raise AssertionError( + f"Found multiple getitem nodes in match, previous: {getitem_node}, new: {n}" + ) + getitem_node = n + if conv_node is None: + raise AssertionError( + "Expected exactly one conv node in the match, found none" + ) + if bn_node is None: + raise AssertionError( + "Expected exactly one bn node in the match, found none" + ) + return (conv_node, bn_node, getitem_node) + + def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]: + """ + Return a 3-tuple of (orig_node, q_node, dq_node). + """ + if not _is_dequantize(n): + raise AssertionError(f"Expected a dequantize node, got: {n}") + q_node = n.args[0] + if not isinstance(q_node, Node): + raise AssertionError( + f"Expected quantize node to be a torch.fx.Node, got {type(q_node)}" + ) + if not _is_quantize(q_node): + raise AssertionError( + f"Expected q_node to be a quantize node, got target={q_node.target}" + ) + orig_node = q_node.args[0] + if not isinstance(orig_node, Node): + raise AssertionError( + f"Expected original node to be a torch.fx.Node, got {type(orig_node)}" + ) + return (orig_node, q_node, n) + + original_nodes = list(_filter_nodes_map(r.nodes_map).values()) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) + + # Create the mapping from original node to replacement node + if o_getitem is not None: + raise AssertionError(f"Expected o_getitem to be None, got {o_getitem}") + if r_getitem is not None: + raise AssertionError(f"Expected r_getitem to be None, got {r_getitem}") + mapping = { + "conv": (o_conv, r_conv), + "bn": (o_bn, r_bn), + } + + # Extract conv input and weight + # Note: here we extract the original nodes indirectly through the pattern nodes + # because the args of the original nodes are no longer available after replacement + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv_input, p_conv_weight, *_) = p_conv.args + (r_conv_input, r_conv_weight, *_) = r_conv.args + if not isinstance(p_conv_input, Node): + raise AssertionError( + f"Expected p_conv_input to be a Node, got {type(p_conv_input)}" + ) + if not isinstance(p_conv_weight, Node): + raise AssertionError( + f"Expected p_conv_weight to be a Node, got {type(p_conv_weight)}" + ) + if not isinstance(r_conv_input, Node): + raise AssertionError( + f"Expected r_conv_input to be a Node, got {type(r_conv_input)}" + ) + if not isinstance(r_conv_weight, Node): + raise AssertionError( + f"Expected r_conv_weight to be a Node, got {type(r_conv_weight)}" + ) + o_conv_input = r.nodes_map[p_conv_input] + o_conv_weight = r.nodes_map[p_conv_weight] + + # If conv weight is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_weight): + p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes( + p_conv_weight + ) + r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes( + r_conv_weight + ) + o_conv_weight = r.nodes_map[p_conv_weight] + o_conv_weight_q = r.nodes_map[p_conv_weight_q] + o_conv_weight_dq = r.nodes_map[p_conv_weight_dq] + mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q) + mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq) + mapping["conv_input"] = (o_conv_input, r_conv_input) + mapping["conv_weight"] = (o_conv_weight, r_conv_weight) + + # Extract conv bias + if len(p_conv.args) > 2 and len(r_conv.args) > 2: + p_conv_bias = p_conv.args[2] + r_conv_bias = r_conv.args[2] + if not isinstance(p_conv_bias, Node): + raise AssertionError( + f"Expected p_conv_bias to be a Node, got {type(p_conv_bias)}" + ) + if not isinstance(r_conv_bias, Node): + raise AssertionError( + f"Expected r_conv_bias to be a Node, got {type(r_conv_bias)}" + ) + o_conv_bias = r.nodes_map[p_conv_bias] + + # If conv bias is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_bias): + p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias) + r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias) + o_conv_bias = r.nodes_map[p_conv_bias] + o_conv_bias_q = r.nodes_map[p_conv_bias_q] + o_conv_bias_dq = r.nodes_map[p_conv_bias_dq] + mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q) + mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq) + mapping["conv_bias"] = (o_conv_bias, r_conv_bias) + return mapping + + +def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]: + """ + Return a filtered `nodes_map` returned from the subgraph rewriter. + The filtered `nodes_map` will contain only nodes that are actually + matched in the pattern, excluding None or placeholder nodes. + """ + new_nodes_map: dict[Node, Node] = {} + for pattern_node, graph_node in nodes_map.items(): + # bias can be None + if graph_node is None: + continue + # skip pattern placeholder nodes + if pattern_node.op == "placeholder": + continue + new_nodes_map[pattern_node] = graph_node + return new_nodes_map + + +# TODO: this is error prone, use the replace_literals_with_placeholders hack instead +def _copy_over_literal_conv_args(original_node: Node, new_node: Node): + """ + Copy over literal args in conv, such as stride and padding, from the matched node + in the original graph to its replacement in the new graph. + + This is needed due to the following limitation in the subgraph rewriter when used + with dynamo export: literal (non-tensor) args are not supported in the match and + replacement patterns. This is because dynamo export automatically inlines these + literal args, making them dead placeholder nodes. In the future, we should check + if dynamo export can optionally disable this inlining, or if subgraph rewriter + can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419. + + Note: Unlike other tensor args like conv weights and biases, literal args are + preserved in the original nodes after replacement, so we can access them here. + """ + if not _is_conv_or_conv_transpose_node(original_node): + raise AssertionError( + f"Expected original_node to be a conv node, got {original_node}" + ) + if not _is_conv_or_conv_transpose_node(new_node): + raise AssertionError(f"Expected new_node to be a conv node, got {new_node}") + # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups] + new_args = list(new_node.args) + if len(new_args) < 3: + # bias is optional, when it is not present, it means it is None + new_args.append(None) + new_node.args = tuple(new_args[:3]) + original_node.args[3:] + + +def _update_conv_input_qspec_map_after_replacement( + original_node: Node, replacement_node: Node +): + """ + Update the `input_qspec_map` in the annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the keys in the `input_qspec_map` will need to be updated to reflect + the corresponding nodes in the replacement graph. + """ + if not _is_conv_or_conv_transpose_node(original_node): + raise AssertionError( + f"Expected original_node to be a conv node, got {original_node}" + ) + if not _is_conv_or_conv_transpose_node(replacement_node): + raise AssertionError( + f"Expected replacement_node to be a conv node, got {replacement_node}" + ) + if "quantization_annotation" not in original_node.meta: + return + original_input_qspec_map = original_node.meta[ + "quantization_annotation" + ].input_qspec_map + input_qspec_map = {} + # get the list of configs, it should be ordered as input, weight, bias + # note: this is really hacky, we need a better solution, hopefully + # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820 + all_configs = list(original_input_qspec_map.items()) + # input activation + input_qspec_map[replacement_node.args[0]] = all_configs[0][1] + # weight + input_qspec_map[replacement_node.args[1]] = all_configs[1][1] + # bias + if len(replacement_node.args) > 2 and len(all_configs) > 2: + input_qspec_map[replacement_node.args[2]] = all_configs[2][1] + replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map + + +def _update_special_qspecs_after_replacement( + node: Node, + original_to_replacement_node: dict[Node, Node], +): + """ + Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s + used in `node`'s quantization annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the nodes used in these special quantization specs will need to + be updated to the corresponding nodes in the replacement graph. + """ + + def _get_new_edge_or_node(edge_or_node: EdgeOrNode): + if isinstance(edge_or_node, Node): + _node = edge_or_node + return original_to_replacement_node.get(_node, _node) + elif ( + isinstance(edge_or_node, tuple) + and len(edge_or_node) == 2 + and all(isinstance(x, Node) for x in edge_or_node) + ): + src, dest = edge_or_node + return ( + original_to_replacement_node.get(src, src), + original_to_replacement_node.get(dest, dest), + ) + else: + raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node)) + + def _get_new_qspec(qspec: QuantizationSpecBase): + if isinstance(qspec, SharedQuantizationSpec): + new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node) + return SharedQuantizationSpec(new_edge_or_node) + elif isinstance(qspec, DerivedQuantizationSpec): + new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from] + return dataclasses.replace(qspec, derived_from=new_derived_from) + else: + return qspec + + if "quantization_annotation" not in node.meta: + return + annotation = node.meta["quantization_annotation"] + for input_node, qspec in annotation.input_qspec_map.items(): + annotation.input_qspec_map[input_node] = _get_new_qspec(qspec) + annotation.output_qspec = _get_new_qspec(annotation.output_qspec) + + +def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fuse_conv_bn_qat_helper( + m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + return m + + +def _fuse_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Given a graph of decomposed aten ops, replace the (conv + bn) pattern with + the fused QAT subgraph equivalent. The input graph should already be annotated. + The annotations in the original nodes will be preserved in the corresponding + nodes in the new subgraph. + + Note: This also handles the (conv + bn + relu) pattern. + """ + m.graph.eliminate_dead_code() + m.recompile() + + conv_bn_pattern = _get_conv_bn_pattern(conv_fn) + match_pattern = _get_aten_graph_module_for_pattern( + conv_bn_pattern, + example_inputs, + is_cuda, + ) + + # Step (1): Replace patterns with conv bias + # + # Here we do replacement separately for cases with and without conv bias, since + # the replacement patterns for these two cases are substantially different. + # TODO: use the public replace_pattern API once it also returns replacement nodes + + qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn) + replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern, + example_inputs, + is_cuda, + ) + replacements_with_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_with_conv_bias, + match_filters=[_has_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (2): Replace patterns without conv bias + + qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn) + replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern_no_conv_bias, + example_inputs, + is_cuda, + ) + replacements_no_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_no_conv_bias, + match_filters=[_no_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (3): Post processing + # + # Due to limited functionality in the subgraph rewriter, here we manually + # update the replacement graph as follows: + # + # (a) Copy over metadata from original subgraph. This ensures the stack traces + # and annotations are preserved in the new subgraph + # + # (b) Copy over literal args for conv from the original subgraph + # TODO: do this for literal args for batchnorm as well + # + # (c) Update all references of the old nodes in the original subgraph to refer + # to the corresponding nodes in the new subgraph in the annotations + # + # In the future, we should try to push as much of this functionality into the + # subgraph rewriter as possible, so we don't have to manually copy anything over. + # For more detail, see https://github.com/pytorch/pytorch/issues/100419. + + all_original_to_replacement_nodes = {} + for r in replacements_with_conv_bias + replacements_no_conv_bias: + replacement_dict = _get_conv_bn_pattern_nodes(r) + # The original conv node's "nn_module_stack" + conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None) + for k, node_tuple in replacement_dict.items(): + original_node, replacement_node = node_tuple + # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem] + replacement_node.meta = original_node.meta + # If original_node is a get_attr node, it doesn't have nn_module_stack. + # In this case, we copy nn_module_stack from the original conv node. + if ( + k in ["conv_input", "conv_weight"] + and conv_nn_module + and "nn_module_stack" not in replacement_node.meta + ): + replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module) + if _is_conv_or_conv_transpose_node(original_node): + # Step (3b): Copy over conv literal args + _copy_over_literal_conv_args(original_node, replacement_node) + # Step (3c): Update old references in the conv node's input_qspec_map + _update_conv_input_qspec_map_after_replacement( + original_node, replacement_node + ) + all_original_to_replacement_nodes[original_node] = replacement_node + + # Step (3c): Update old references in the special qspecs for all nodes in the graph + for n in m.graph.nodes: + _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes) + + return m + + +def _duplicate_dequantize_node(m: GraphModule): + """ + Helper function to duplicate all dequantize nodes in the graph if the + node has more than one user. For example: + + Before: + quantize -> dequantize -> a + \\--> b + \\--> c + + After: + quantize -> dequantize_1 -> a + \\--> dequantize_2 -> b + \\--> dequantize_3 -> c + + This is useful for subgraph rewriting. E.g. if we wish to match the + pattern [dequantize - a] above, subgraph matching would fail because + the dequantize node has users outside the matched portion of the graph. + Instead, we match [dequantize_1 - a], which is safe. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + if n.op != "call_function" or n.target != dq_op or len(n.users) == 1: + continue + for user in list(n.users): + with m.graph.inserting_before(n): + new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs) + user.replace_input_with(n, new_node) + m.graph.erase_node(n) + m.recompile() + + +def _remove_extra_dequantize(m: GraphModule): + """ + Removes duplicate dequant nodes in the graph, for an operator that has + multiple dequant nodes as a user. Replace them with a single dequant node + that can be shared across all uses. This should be seen as the "reverse" + of `_duplicate_dequantize_node`. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + dq_users = [ + user + for user in n.users + if user.op == "call_function" and user.target == dq_op + ] + if len(dq_users) > 1: + with m.graph.inserting_after(dq_users[0]): + new_node = m.graph.create_node( + "call_function", dq_op, dq_users[0].args, {} + ) + for dq_user in dq_users: + dq_user.replace_all_uses_with(new_node) + m.graph.erase_node(dq_user) + m.recompile() + + +def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): + """ + Given a pair of quantize or dequantize nodes, copy over all literal args + from the original node to the replacement node. + """ + # For quantize_per_tensor, scale and zp are literals and need to be copied + # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped + if original_node.target != replacement_node.target: + raise AssertionError( + "Expected original and replacement nodes to have the same target, got " + f"{original_node.target} != {replacement_node.target}" + ) + if original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Args: input, [scale, zp, qmin, qmax, dtype] + start_copy_arg_index = 1 + elif original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ): + # Args: input, scale, zp, [axis, qmin, qmax, dtype] + start_copy_arg_index = 3 + else: + raise ValueError( + f"Expected quantize/dequantize nodes, got '{original_node.target}'" + ) + replacement_node.args = ( + replacement_node.args[:start_copy_arg_index] + + original_node.args[start_copy_arg_index:] + ) + + +def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for quantized and folded conv-bn1d patterns used in convert + _quantized_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for quantized and folded conv-bn2d patterns used in convert + _quantized_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fold_conv_bn_qat_helper( + m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + + # remove in place add from batchnorm tracking training stats + for node in m.graph.nodes: + if ( + node.target is torch.ops.aten.add_.Tensor + and node.args[0].op == "get_attr" + and node.args[1] == 1 + and ( + torch.nn.modules.batchnorm.BatchNorm2d + in [val[1] for val in node.meta["source_fn_stack"]] + or torch.nn.modules.batchnorm.BatchNorm1d + in [val[1] for val in node.meta["source_fn_stack"]] + ) + ): + m.graph.erase_node(node) + + m.graph.eliminate_dead_code() + m.recompile() + + return m + + +def _fold_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv. + """ + + m.graph.eliminate_dead_code() + m.recompile() + _duplicate_dequantize_node(m) + + # Step (1): Replace QAT pattern with simple [conv - bn] pattern + replacements = [] + replacement_options = itertools.product( + [True, False], # is_per_channel + [True, False], # has_bias + [True, False], # bias_is_quantized + [True, False], # bn_is_training + ) + for ( + is_per_channel, + has_bias, + bias_is_quantized, + bn_is_training, + ) in replacement_options: + # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily + # filter out one of the values for this flag to avoid having duplicate patterns + if not has_bias and bias_is_quantized: + continue + kwargs = _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel, has_bias, bias_is_quantized, is_cuda + ) + match_pattern = _get_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + match_pattern = _get_aten_graph_module_for_pattern( + match_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + replacement_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacements.extend( + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + ignore_literals=True, + ) + ) + m.recompile() + _remove_extra_dequantize(m) + + for r in replacements: + node_map = _get_conv_bn_pattern_nodes(r) + + # Step (2): Copy over metadata from original subgraph + for original_node, replacement_node in node_map.values(): + replacement_node.meta = original_node.meta + + # Step (3): Copy over args for weight (and optionally bias) q - dq nodes + _copy_over_q_dq_args(*node_map["conv_weight_q"]) + _copy_over_q_dq_args(*node_map["conv_weight_dq"]) + if "conv_bias_q" in node_map: + if "conv_bias_dq" not in node_map: + raise AssertionError( + "Expected 'conv_bias_dq' to be present in node_map when 'conv_bias_q' is present" + ) + _copy_over_q_dq_args(*node_map["conv_bias_q"]) + _copy_over_q_dq_args(*node_map["conv_bias_dq"]) + + # Step (4): Fold BN weights into conv + conv_bias = None + (_, conv_node) = node_map["conv"] + (_, bn_node) = node_map["bn"] + (_, conv_weight) = node_map["conv_weight"] + if "conv_bias" in node_map: + (_, conv_bias) = node_map["conv_bias"] + fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m) + + # Copy over literal args for conv + for original_node in _filter_nodes_map(r.nodes_map).values(): + if _is_conv_or_conv_transpose_node(original_node): + _copy_over_literal_conv_args(original_node, conv_node) + + m.graph.eliminate_dead_code() + m.recompile() + return m diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69a74ea6a0dfaf541a7617a16419013cce597bdd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py @@ -0,0 +1,625 @@ +# mypy: allow-untyped-defs +import operator +import types +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 +import torch.nn.functional as F +import torch.utils._pytree as pytree + +# Makes sure that quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.nn.utils.fusion import fuse_conv_bn_weights + + +__all__ = [ + "fold_bn_weights_into_conv_node", + "remove_tensor_overload_for_qdq_ops", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError( + f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}" + ) + dest = dest.args[0] + return dest == source + + +def _find_q_dq_node_for_user( + produer: torch.fx.Node, user: torch.fx.Node +) -> tuple[Any, Any]: + """ + Find q, dq pair corresponding to [producer -> q -> dq -> user] + Utils works by finding dq arg of user and ensuring it is connected to + producer + """ + dq_node = None + for n in user.args: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + for n in user.kwargs: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + return (None, None) + + q_node = None + if ( + isinstance(arg := dq_node.args[0], torch.fx.Node) + and arg.op == "call_function" + and arg.target in _QUANTIZE_OPS + ): + q_node = arg + return (q_node, dq_node) + + +def _is_sym_size_node(node: Node): + return ( + node.op == "call_function" + and node.target is torch.ops.aten.sym_size.default + or node.target is torch.ops.aten.sym_numel.default + or node.target is torch.ops.aten.sym_numel + or node.target is torch.ops.aten.sym_size + ) + + +def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]: + node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) + return node_users + + +def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: + if annotation is None: + return False + input_qspec_map = annotation.input_qspec_map + output_qspec = annotation.output_qspec + if len(input_qspec_map) == 0 and output_qspec is None: + return False + return True + + +def _get_tensor_constant_from_node(node, m): + if node is None: + return None + if node.op != "get_attr": + raise AssertionError(f"Expected node.op to be 'get_attr', got {node.op}") + target_atoms = node.target.split(".") + attr_itr = m + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def _get_all_arguments(orig_args, orig_kwargs, args_schema): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + +def _is_supported_batch_norm_for_training(node: Node): + """ + Return True if the given node refers to an aten batch norm op QAT supports. + """ + supported_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + # Note: we won't need this op anymore after batch norm consolidation + # For now, we need to continue to support it because it gives better + # training numerics than `_native_batch_norm_legit` + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.miopen_batch_norm.default, + ] + return node.target in supported_ops + + +# TODO: move this to torch/ao/quantization/utils.py +def _is_conv_node(n: Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, + ] + + +def _is_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv_transpose op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + ] + + +def _is_conv_or_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv or conv transpose op. + """ + return _is_conv_node(n) or _is_conv_transpose_node(n) + + +def _is_conv_transpose_fn(conv_fn: Callable): + return conv_fn in [F.conv_transpose1d, F.conv_transpose2d] + + +def _is_bn_node(n: Node): + return ( + _is_supported_batch_norm_for_training(n) + or n.target is torch.ops.aten._native_batch_norm_legit_no_training.default + ) + + +def fold_bn_weights_into_conv_node( + conv_node: Node, + conv_weight_node: Node, + conv_bias_node: Node | None, + bn_node: Node, + m: GraphModule, +) -> None: + # conv args: input, weight, bias, stride, padding, dilation, ... + conv_w = _get_tensor_constant_from_node(conv_weight_node, m) + conv_b = _get_tensor_constant_from_node(conv_bias_node, m) + transpose = _is_conv_transpose_node(conv_node) + + # eval bn args: input, weight, bias, running mean, running var, momentum, eps + # train bn args: input, weight, bias, running mean, running var, training, momentum, eps + bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] + bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) + bn_w = _get_tensor_constant_from_node(bn_args[1], m) + bn_b = _get_tensor_constant_from_node(bn_args[2], m) + bn_rm = _get_tensor_constant_from_node(bn_args[3], m) + bn_rv = _get_tensor_constant_from_node(bn_args[4], m) + if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default: + eps_arg_index = 6 + elif _is_supported_batch_norm_for_training(bn_node): + eps_arg_index = 7 + else: + raise ValueError("BN node target is unexpected ", bn_node.target) + bn_eps = bn_args[eps_arg_index] + + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose + ) + + # update the weight and bias for conv + conv_args = list(conv_node.args) + # filling in the default bias argument + if len(conv_args) == 2: + conv_args.append(None) + + # calling data since the fused_weight and fused_bias are nn.Parameter + weight_attr_name = conv_weight_node.target + if not isinstance(weight_attr_name, str): + raise AssertionError( + f"Expected conv_weight_node.target to be a string attribute name, got {type(weight_attr_name)}" + ) + _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) + else: + bias_attr_name = weight_attr_name + "_bias" + _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) + with m.graph.inserting_before(conv_node): + get_bias_node = m.graph.get_attr(bias_attr_name) + # NOTE: here we assume the bias of conv is not quantized! + conv_args[2] = get_bias_node + conv_node.args = tuple(conv_args) + + # native_batch_norm has 3 outputs, we expect getitem calls on the output + # and we want to replace the uses of getitem 0 with the output of conv + # + if bn_node.target is torch.ops.aten.batch_norm.default: + # With the new training ir, instead of batch_norm + getitem, + # we only have the batch_norm node. + # + # Before: + # conv -> bn -> users + # After: + # conv -> users + # bn has no users now + bn_node.replace_all_uses_with(conv_node) + else: + # Before: + # conv -> bn - (first output) -> users1 + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # After: + # conv -> (first output) -> users1 + # bn - + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # if users2 and users3 are empty then bn will be removed through dead code elimination + for user in bn_node.users: + if ( + user.op != "call_function" + or user.target != operator.getitem + or user.args[1] != 0 + ): + continue + user.replace_all_uses_with(conv_node) + + # If the BN node does not have users, erase it from the graph + # Note: we need to do this manually because the model can still be in train + # mode at this point, in which case DCE won't erase the BN node automatically + # since the node refers to a mutating op. Here we still need to call DCE first + # to get rid of the unused getitem nodes that consume the BN node. + m.graph.eliminate_dead_code() + if len(bn_node.users) == 0: + m.graph.erase_node(bn_node) + + +# fuse conv bn weights, inplace modification of the graph_module and graph +def _fuse_conv_bn_(m: GraphModule) -> None: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return + for n in m.graph.nodes: + if n.op != "call_function" or n.target not in ( + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.batch_norm.default, + ): + continue + bn_node = n + n = bn_node.args[0] + if not _is_conv_or_conv_transpose_node(n): + continue + conv_node = n + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + fold_bn_weights_into_conv_node( + conv_node, conv_weight_node, conv_bias_node, bn_node, m + ) + + m.graph.eliminate_dead_code() + m.recompile() + + +def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]: + # TODO: move this information to fx node itself + node_name_to_scope: dict[str, tuple[str, type]] = {} + for n in model.graph.nodes: + nn_module_stack = n.meta.get("nn_module_stack", None) + current_scope = ("", type(None)) + if nn_module_stack: + bt = list(nn_module_stack.values())[-1] + current_scope = (bt[0].split(".")[-1], bt[1]) + node_name_to_scope[n.name] = current_scope + return node_name_to_scope + + +def _get_aten_graph_module_for_pattern( + pattern: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool = False, + **kwargs, +) -> GraphModule: + """ + Convert the pattern to an FX graph with decomposed aten ops. + """ + if is_cuda: + example_inputs = tuple( + x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs + ) + + with torch._export.config.patch(use_new_tracer_experimental=True): + aten_pattern = torch.export.export( + pattern, # type: ignore[arg-type] + example_inputs, + kwargs, + strict=True, + ).module(check_guards=False) + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + # ep.module() adds copy_ nodes for the mutated inputs. + # For patterns, it doesn't matter + for node in aten_pattern.graph.nodes: # type: ignore[union-attr] + if ( + node.op == "call_function" + and node.target is torch.ops.aten.copy_.default + and len(node.users) == 0 + ): + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + return aten_pattern # type: ignore[return-value] + + +def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: + """Remove .tensor overload for quantize/dequantize ops so that we can + use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e + """ + _MAP = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, + torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, + torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, + } + for n in match_pattern.graph.nodes: + if n.op != "call_function": + continue + if n.target in _MAP: + n.target = _MAP[n.target] + + +def _is_literal(arg): + if isinstance(arg, (int, float)): + return True + if isinstance(arg, (tuple, list)): + return all(map(_is_literal, arg)) + return False + + +def _replace_literals_with_new_placeholders( + gm: torch.fx.GraphModule, + merge_dup: bool = False, + exclude_literals: list[Any] | None = None, +): + """Replace the literals in the graph with placeholder nodes that's created on the fly while we + traverse the graph, so that the literal arguments in the graph can be matched and replaced + + To use this, the pattern and replacement graph should have the exact same number of literal args + and they should be used in the exact same order in the pattern and replacement graph. + + If the literal arguments are not used in the same order in pattern and replacement graph, please + use `_replace_literals_with_existing_placeholders` instead + + Args: + `gm`: input GraphModule that we'll transform + `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in + the graph, whether they should correspond to the same placeholder or not + `exclude_literals`: a list of literals that will not be replaced with placeholders + + Example: + + # 1. Original Graph + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + example_inputs = (torch.randn(1, 3, 3, 3),) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) + replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) + + # 3. After replacing literals with new placeholder nodes + + def pattern(self, x, new_ph): + return x + new_ph + + def pattern(self, x, new_ph): + return x - new_ph + + """ + last_ph = None + cnt = 0 + literal_to_ph: dict[float | bool | int | torch.dtype, Node] = {} + if exclude_literals is None: + exclude_literals = [] + + in_spec = gm._in_spec + assert in_spec.type is tuple + args_spec = in_spec.child(0) + assert args_spec.type is tuple + args_spec_children = args_spec.children() + for node in gm.graph.nodes: + if node.op == "placeholder": + last_ph = node + cnt += 1 + continue + with gm.graph.inserting_after(last_ph): + new_args = [] + for arg in node.args: + if _is_literal(arg) and arg not in exclude_literals: + if merge_dup and arg in literal_to_ph: + new_args.append(literal_to_ph[arg]) + else: + ph_node = gm.graph.placeholder("arg" + str(cnt)) + new_args.append(ph_node) + args_spec_children.append(pytree.treespec_leaf()) + cnt += 1 + if merge_dup: + literal_to_ph[arg] = ph_node + else: + new_args.append(arg) + new_args = tuple(new_args) + + node.args = new_args + + # Update `num_nodes`, `num_leaves`, `num_children`. + args_spec = pytree.treespec_tuple(args_spec_children) + gm._in_spec = in_spec = pytree.treespec_tuple([args_spec, *in_spec.children()[1:]]) + return gm + + +def _replace_literals_with_existing_placeholders( + gm: torch.fx.GraphModule, + exclude_literals: list[Any] | None = None, + literal_to_ph_idx: dict[float | int | bool | torch.dtype, int] | None = None, +): + """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments + in the graph can be matched and replaced + + To use this, all literal args in the graph should be unique and each of them should correspond + to exactly one placeholder node + + # 1. Original Graph + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + 1.0, + 0, + -128, + 127, + ) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, -128, 127) + return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32) + + # Note that literal args appear in different order in pattern and replacement graph, so + # we can't use _replace_literals_with_new_placeholders + + literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4} + pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx) + replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx) + + # 3. After replacing literals with existing placeholder nodes + + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + """ + if exclude_literals is None: + exclude_literals = [] + + if literal_to_ph_idx is None: + literal_to_ph_idx = {} + + phs = [node for node in gm.graph.nodes if node.op == "placeholder"] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + new_args = [] + for arg in node.args: + if ( + _is_literal(arg) + and arg not in exclude_literals + and arg in literal_to_ph_idx + ): + ph_idx = literal_to_ph_idx[arg] + ph_node = phs[ph_idx] + new_args.append(ph_node) + else: + new_args.append(arg) + new_args = tuple(new_args) + node.args = new_args + return gm + + +# TODO: Handle this in export itself and don't wrap the model in another GraphModule +# in prepare and convert +def _disallow_eval_train(model: GraphModule): + """ + Disallow calling `model.train()` or `model.eval()` on the given GraphModule. + This is useful for exported models, where these methods don't actually behave as expected. + """ + error_message = """ + Calling train() or eval() is not supported for exported models. + Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead. + + If you cannot replace the calls to `model.train()` and `model.eval()`, you may override + the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`, + which does the above automatically for you. Note that this has limited effect on switching + behavior between train and eval modes, and should be used only for special ops such as dropout + and batchnorm. + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..523d0083cace61b7537ea13375d5b793b1e50f8c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3536be9b6555516e7095bd5608c4a6f33c6cb9fe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7364f6a53f85e89028651152c6ad68629d2f034 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c823809354dc7afa65e2739e03db2c7636a0f7a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dae673c7b2313480a940a9cc19517dba21d20d3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__init__.py @@ -0,0 +1,3 @@ +# pyrefly: ignore [deprecated] +from .autocast_mode import autocast +from .grad_scaler import GradScaler diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc9c95f33db433e2762e01f1f99398222ed2e87 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab7e1362acaa693427d98e605b3f933690cb16db Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d45b4a7dbe9cec5b3a4a537713fbde9e49f9757 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f81060d4a01fc6857138c49ec8276bee59b90d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/autocast_mode.py @@ -0,0 +1,71 @@ +# mypy: allow-untyped-defs +import sys +from typing import Any +from typing_extensions import deprecated + +import torch + + +__all__ = ["autocast"] + + +@deprecated( + "`torch.cpu.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cpu', args...)` instead.", + category=FutureWarning, +) +class autocast(torch.amp.autocast_mode.autocast): + r""" + See :class:`torch.autocast`. + ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. + """ + + # TODO: remove this conditional once we stop supporting Python < 3.13 + # Prior to Python 3.13, inspect.signature could not retrieve the correct + # signature information for classes decorated with @deprecated (unless + # the __new__ static method was explicitly defined); + # + # However, this issue has been fixed in Python 3.13 and later versions. + if sys.version_info < (3, 13): + + def __new__( + cls, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + return super().__new__(cls) + + def __init_subclass__(cls): + pass + + def __init__( + self, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = "cpu" + self.fast_dtype = dtype + return + super().__init__( + "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled + ) + + def __enter__(self): + if torch._jit_internal.is_scripting(): + return self + return super().__enter__() + + # TODO: discuss a unified TorchScript-friendly API for autocast + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] + if torch._jit_internal.is_scripting(): + return + return super().__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, func): + if torch._jit_internal.is_scripting(): + return func + return super().__call__(func) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..aefaa1c323f5ff9089fc69c7a7aabbb380cc7233 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/amp/grad_scaler.py @@ -0,0 +1,35 @@ +from typing_extensions import deprecated + +import torch + + +__all__ = ["GradScaler"] + + +class GradScaler(torch.amp.GradScaler): + r""" + See :class:`torch.amp.GradScaler`. + ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead. + """ + + @deprecated( + "`torch.cpu.amp.GradScaler(args...)` is deprecated. " + "Please use `torch.amp.GradScaler('cpu', args...)` instead.", + category=FutureWarning, + ) + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + enabled: bool = True, + ) -> None: + super().__init__( + "cpu", + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8310b32e135b6f38d8c807599cf4fb0860654d63 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32f654bb5cdc1420deb023fac10bd55b9663d5b2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7796fb93a9a6f470658ac9578f379e325ae28d65 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_composable_state.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c97d4b5db2ebeba02667757cf1c474ce246ac6ea Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_dist2.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6d37dc0516b23cca56b1de76a61f1516b76468 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23ce1f6b906a02f838b56aaacc3ad8d70621d6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d042def43637566d94c88701c050b6b61a5051b1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_mesh_layout.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de776d6941526625ea6500a4a02eaa931670519 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_serialization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d63e2d893ccb0cc1674953534b00258f97db938 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55f7f1a659afbe153cb2dd52789cf00b3ffedbf8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/argparse_util.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d7db912f331fe103823c37160fb7b688e7f59a5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7075a1fcb4fcd89f31932f71de353948056bca6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/collective_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf529468d41ef886c5d6800ec77f9b211e0fb6d4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/constants.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf058e9eba2b566b71b6dad84ccda3378928ba28 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/device_mesh.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7e61c9442ceaefd9d50356470a225763fe2df4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/launch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..185e7b1ce7f67c50125a343c363d2a6d05cd376e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..025ad014b4e050fd02cf973d0c3a427e723f8076 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/remote_device.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8fb328abe24f0dfcaecc9866ccf0475b13f55a0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/rendezvous.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0dd62e8c11f292f1a60eb42ec6d64e2e0fa2dd6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/run.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2331f65dc3b6952765fb386424971f400dcd71d4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e38281810696814a7eae148eff19b58c10e072b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__init__.py @@ -0,0 +1,3 @@ +from .checkpoint_activation import checkpoint +from .contract import _get_registry, contract +from .replicate import replicate diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b4d2a1d2d7c9bfa195f9fb0f6356fef878d9c59 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa7a577dd9467a9ba6e916cc565adbeda4a9b370 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..400eb47fc34a2baa8a0f96d1020d3e5a05486f6a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41af53f0e5fac9ad530b0b0e77a6c8f19284fde1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b73cbad12620e957469a3d479b52ae85232de817 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/__pycache__/replicate_with_fsdp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..93ae14110ef79a3b9b065c4ca1e8af613bd90ff5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/checkpoint_activation.py @@ -0,0 +1,134 @@ +# mypy: allow-untyped-defs +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any + +import torch +import torch.nn as nn +from torch.utils.checkpoint import ( + _checkpoint_without_reentrant_generator, + _DEFAULT_DETERMINISM_MODE, +) + +from .contract import _State, contract + + +@contextmanager +def _no_hook(module: nn.Module, user_ctx: AbstractContextManager | None = None): + r""" + Disable hooks installed by checkpoint to avoid unintentional recursion + during backward recomputation. + """ + + with user_ctx if user_ctx else nullcontext(): + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook + + +class _CheckpointState(_State): + enable_hook: bool = False + _ac_generator: Generator[None, None, None] | None + + +@contract(_CheckpointState) +def checkpoint(module: nn.Module, **kwargs) -> nn.Module: + r""" + This is a composable activation checkpointing API. Unlike functional + activation checkpointing APIs, this one does not require changing model + source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs, + this one does not modify model structure or fully-qualified names either. + Under the hood, it registers activation checkpointing logic as pre- and + post-forward hooks. Hence, this API can be easily applied to any model or + sub-modules in the model. + + Args: + module (nn.Module): the target model or sub-module to apply activation + checkpointing. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> model = MyModel() + >>> checkpoint(model.l1) # apply activation checkpointing only to l1 + >>> model(torch.zeros(2, 10)).sum().backward() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint") + + use_reentrant = kwargs.pop("use_reentrant", False) + if use_reentrant: + raise NotImplementedError( + "use_reentrant=True is not supported in composable checkpoint. " + "Please use torch.utils.checkpoint.checkpoint instead." + ) + preserve_rng_state = kwargs.pop("preserve_rng_state", True) + user_context_fns = kwargs.pop("context_fn", None) + determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) + debug = kwargs.pop("debug", False) + early_stop = kwargs.pop("early_stop", True) + + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def forward_pre_hook( + module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + if checkpoint.state(module).enable_hook: + + def context_fns(): + if user_context_fns is not None: + ctx1, ctx2 = user_context_fns() + return ctx1, _no_hook(module, ctx2) + else: + return nullcontext(), _no_hook(module) + + gen = _checkpoint_without_reentrant_generator( + module, + preserve_rng_state, + context_fns, + determinism_check, + debug, + early_stop, + *args, + **kwargs, + ) + checkpoint.state(module)._ac_generator = gen + next(gen) + + def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any: + if checkpoint.state(module).enable_hook: + try: + gen = checkpoint.state(module)._ac_generator + assert gen is not None + next(gen) + except StopIteration: + pass + else: + raise RuntimeError( + "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!" + ) + + # Ensure that we no longer hold on to the generator. always_call=True helps ensure we + # clear this even in the case of exception in fwd pass. + checkpoint.state(module)._ac_generator = None + + checkpoint.state(module).enable_hook = True + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + module.register_forward_hook(forward_hook, prepend=True, always_call=True) + return module diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/contract.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/contract.py new file mode 100644 index 0000000000000000000000000000000000000000..c810da8cb583c1199cda7087f7feb45b8ab6c443 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/contract.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +import uuid +from collections import OrderedDict +from collections.abc import Callable +from functools import wraps +from typing import Concatenate, Generic, Protocol +from typing_extensions import ParamSpec, TypeVar + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.distributed.utils import _get_root_modules + + +_T = TypeVar("_T", covariant=True) +_P = ParamSpec("_P") + + +def generate_state_key(string="__composable_api_state_key"): + return f"{string}_{str(uuid.uuid4())}" + + +STATE_KEY = generate_state_key() +REGISTRY_KEY = generate_state_key() + + +# TODO: we can add additional info to RegistryItem to share across APIs. E.g., +# we can add args and kwargs here, and then we can detect whether fully_shard +# is combined with reentrant activation checkpointing and error out with a clear +# message. +class RegistryItem: + pass + + +_TState = TypeVar("_TState", bound="_State", covariant=True) +_M = TypeVar("_M", nn.Module, list[nn.Module]) + + +class _ContractFn(Protocol, Generic[_P, _T, _TState]): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + + def state(self, module: nn.Module) -> _TState: ... + + +def contract( + state_cls: type[_TState] = _State, # type: ignore[assignment] +) -> Callable[ + [Callable[Concatenate[_M, _P], _M]], + _ContractFn[Concatenate[_M, _P], _M, _TState], +]: + r""" + Decorate a function as a composable distributed API, where the first + argument of the function must be an :class:`nn.Module` instance or sequence + of :class:`nn.Module` instances. + + The decorator verifies that the decorated function does not modify + fully-qualified names (FQNs) for parameters, buffers, or modules. The + decorated function can return different module instances than the input + modules; the FQN invariant will be enforced following the input order. + + When a function ``func`` is decorated by ``@contract()``, a + ``.state(module: nn.Module)`` method will be installed to the decorated + function. Then you can retrieve and modify the state on a module by calling + ``func.state(module)``. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> @contract() + >>> def my_feature(module: nn.Module) -> nn.Module: + >>> my_feature.state(module).some_state = "any value" + >>> return module + >>> + >>> model = MyModel() + >>> my_feature(model.l1) + >>> assert my_feature.state(model.l1).some_state == "any value" + >>> my_feature(model.l2) + >>> model(torch.randn(2, 10)).sum().backward() + """ + + # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package + @wraps(state_cls) # type: ignore[arg-type] + def inner( + func: Callable[Concatenate[_M, _P], _M], + ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]: + @wraps(func) + def wrapper( + module: _M, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _M: + inp_module = module + modules: list[nn.Module] + if isinstance(module, nn.Module): + modules = [module] + else: + # If the user passes a sequence of modules, then we assume that + # we only need to insert the state object on the root modules + # (i.e. those without a parent) among the passed-in modules. + # pyrefly: ignore [no-matching-overload] + modules = _get_root_modules(list(module)) + state = state_cls() # shared across all modules + registry_item = RegistryItem() # shared across all modules + + # `func` is allowed to return different module instances than the + # input modules as long as FQNs are preserved following the input + # module order + all_orig_named_params: list[dict[str, nn.Parameter]] = [] + all_orig_named_buffers: list[dict[str, torch.Tensor]] = [] + all_orig_named_modules: list[dict[str, nn.Module]] = [] + + # pyrefly: ignore [bad-assignment] + for module in modules: + default_all_state: dict[Callable, _State] = OrderedDict() + default_registry: dict[str, RegistryItem] = OrderedDict() + all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, default_all_state + ) + if not isinstance(all_state, dict): + raise AssertionError( + f"Distributed composable API states corrupted: {all_state}" + ) + registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] + REGISTRY_KEY, default_registry + ) + if not isinstance(registry, dict): + raise AssertionError( + f"Distributed composable API registry corrupted: {registry}" + ) + if func in all_state or func.__name__ in registry: + raise AssertionError( + "Each distinct composable distributed API can only be applied to a " + f"module once. {func.__name__} has already been applied to the " + f"following module:\n{module}" + ) + all_state.setdefault(func, state) + registry.setdefault(func.__name__, registry_item) + + # pyrefly: ignore [missing-attribute] + all_orig_named_params.append(OrderedDict(module.named_parameters())) + # pyrefly: ignore [missing-attribute] + all_orig_named_buffers.append(OrderedDict(module.named_buffers())) + # pyrefly: ignore [missing-attribute] + all_orig_named_modules.append(OrderedDict(module.named_modules())) + + updated = func(inp_module, *args, **kwargs) + if updated is None: + updated = inp_module # type: ignore[assignment] + updated_modules: list[nn.Module] + if isinstance(updated, nn.Module): + updated_modules = [updated] + else: + updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload] + + all_new_named_params: list[dict[str, nn.Parameter]] = [] + all_new_named_buffers: list[dict[str, torch.Tensor]] = [] + all_new_named_modules: list[dict[str, nn.Module]] = [] + # pyrefly: ignore [bad-assignment] + for module in updated_modules: + # pyrefly: ignore [missing-attribute] + all_new_named_params.append(OrderedDict(module.named_parameters())) + # pyrefly: ignore [missing-attribute] + all_new_named_buffers.append(OrderedDict(module.named_buffers())) + # pyrefly: ignore [missing-attribute] + all_new_named_modules.append(OrderedDict(module.named_modules())) + + num_orig_modules = len(all_orig_named_modules) + num_new_modules = len(all_new_named_modules) + if num_orig_modules != num_new_modules: + raise AssertionError( + f"{func.__name__} should return the same number of modules as input modules" + f"Inputs: {num_orig_modules} modules\n" + f"Outputs: {num_new_modules} modules" + ) + + def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str): + if orig_fqns == new_fqns: + return + + orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) + orig_only = orig_fqn_set - new_fqn_set + new_only = new_fqn_set - orig_fqn_set + if len(orig_only) or len(new_only): + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify FQNs.\n" + f"FQNs only in original: {orig_only}\n" + f"FQNs only in new: {new_only}" + ) + else: + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify " + "the order of FQNs.\n" + f"Original FQNs: {orig_only}\n" + f"New FQNs: {new_only}" + ) + + for orig_named_params, new_named_params in zip( + all_orig_named_params, all_new_named_params + ): + check_fqn( + list(orig_named_params.keys()), + list(new_named_params.keys()), + "Checking parameters: ", + ) + for orig_named_buffers, new_named_buffers in zip( + all_orig_named_buffers, all_new_named_buffers + ): + check_fqn( + list(orig_named_buffers.keys()), + list(new_named_buffers.keys()), + "Checking buffers: ", + ) + for orig_named_modules, new_named_modules in zip( + all_orig_named_modules, all_new_named_modules + ): + check_fqn( + list(orig_named_modules.keys()), + list(new_named_modules.keys()), + "Checking modules: ", + ) + + # TODO: verify that installed distributed paradigms are compatible with + # each other. + + # pyrefly: ignore [bad-return] + return updated + + def get_state(module: nn.Module) -> _State: + return module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, + {}, # TODO(@yhcharles): this is a temporary fix, need a better way + ).get(func) # type: ignore[call-overload] + + wrapper.state = get_state # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + return inner # type: ignore[return-value] + + +def _get_registry(module: nn.Module) -> dict[str, RegistryItem] | None: + r""" + Get an ``OrderedDict`` of composable APIs that have been applied to the + ``module``, indexed by the API name. If no API has been applied, then this + returns ``None``. + """ + return getattr(module, REGISTRY_KEY, None) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..108c765ba4766bf7d9110aa67e09ac02cab00410 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__init__.py @@ -0,0 +1,3 @@ +from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy + +from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9f31a658dc50036a64d23c0fa2ffdcd507268d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b274ddbf8b58b068a526f98c56c778baab4c32d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..9e36c7b430fc89dd58cc5742f299ac607eb4367b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/fsdp/fully_shard.py @@ -0,0 +1,8 @@ +# TODO: For backward compatibility, we are importing the public objects +# originally from this file. +from torch.distributed.fsdp import ( # noqa: F401 + FSDPModule, + fully_shard, + register_fsdp_forward_method, + UnshardHandle, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdec49468703e53b0a125a0d3c71a92ec80d00c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import weakref +from collections.abc import Iterable +from typing import Any, NoReturn + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.nn.parallel import DistributedDataParallel + +from .contract import _get_registry, contract + + +_ROOT_MODULE_PREFIX = "" + + +class _ReplicateState(_State): + _ddp_weakref: weakref.ref + + def __init__(self) -> None: + super().__init__() + self.module: nn.Module = nn.ParameterList() + self.has_initialized: bool = False + self._param_list: nn.ParameterList = nn.ParameterList() + # TODO(@fegin): this variable is originally create for testing, we + # should remove this if possible. + self._orig_module = self.module + self._param_names: list[str] = [] + self._no_sync: bool = False + self._init_args: tuple[Any, ...] | None = None + self._init_kwargs: dict[str, Any] = {} + self._comm_hook_args: list[Any] = [] + + def _collect_params( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + prefix: str = _ROOT_MODULE_PREFIX, + ) -> None: + # skip if managed by fully_sharded API + if _is_fully_sharded(module): + return + + # if a module is ignored, all descendants of the module are ignored. + if module in ignored_modules: + return + + recurse_prefix = ( + f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX + ) + + for n, p in module.named_parameters(recurse=False): + if p not in ignored_params: + self._param_list.append(p) + self._param_names.append(f"{recurse_prefix}{n}") + + for name, child_module in module.named_children(): + self._collect_params( + child_module, + ignored_modules, + ignored_params, + prefix=f"{recurse_prefix}{name}", + ) + + def lazy_init(self) -> None: + @torch._disable_dynamo(recursive=True) + def _lazy_init(): + assert self._init_args is not None + self.init(*self._init_args, **self._init_kwargs) + self.register_comm_hook() + self._init_args = () + self._init_kwargs = {} + + _lazy_init() + + def init( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + **kwargs, + ) -> None: + if self.has_initialized: + return + + self.has_initialized = True + self.module = module + ignored_params = {p for m in ignored_modules for p in m.parameters()} + for submodule in module.modules(): + if _is_fully_sharded(submodule): + ignored_params.update(submodule.parameters()) + from torch.distributed.tensor.parallel.ddp import _localize_dtensor + + _localize_dtensor(module, ignored_params=ignored_params) + self._collect_params(module, ignored_modules, ignored_params) + + if "device_id" in kwargs: + # replicate() supports a small usability enhancement where + # user can pass in device_id as a Union[int, torch.device] even for + # CPU devices so users don't have to change code for CPU/GPU runs. + # We derive the right device_ids to feed into DDP to support this. + if kwargs["device_id"] is not None: + device_id = kwargs["device_id"] + # Convert to device_ids that DDP expects. + if isinstance(device_id, torch.device) and device_id.type == "cpu": + # CPU modules receive device_ids None + kwargs["device_ids"] = None + else: + # GPU modules expect device_ids=[cuda_device] + kwargs["device_ids"] = [device_id] + else: + kwargs["device_ids"] = None + kwargs.pop("device_id") + + self._ddp = DistributedDataParallel(self._param_list, **kwargs) + # Weakref to the DDP instance is currently only used for testing. + replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) + + def register_comm_hook(self) -> None: + for comm_args, comm_kwargs in self._comm_hook_args: + self._ddp.register_comm_hook(*comm_args, **comm_kwargs) + self._comm_hook_args.clear() + + def record_init_args(self, *args, **kwargs) -> None: + self._init_args = args + self._init_kwargs = kwargs + + def forward_pre_hook( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + if self._init_args or self._init_kwargs: + self.lazy_init() + self._ddp.require_backward_grad_sync = not self._no_sync + return self._ddp._pre_forward(*args, **kwargs) + + def forward_post_hook( + self, + module: nn.Module, + input: tuple[torch.Tensor], + output: torch.Tensor, + ) -> torch.Tensor: + return self._ddp._post_forward(output) + + +def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "DDP does not support deepcopy. Please use state dict for serialization." + ) + + +# Follow the same pattern as FSDP/fully_shard +class DDP: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the DDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `DDP<...>` class + # and index 1 is the `DDP` class itself + orig_cls = cls.__mro__[2] + return orig_cls.__new__(orig_cls, *args, **kwargs) + + def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation without communication. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + """ + replicate.state(self)._no_sync = not requires_gradient_sync # type: ignore[arg-type] + + def register_comm_hook(self, *args, **kwargs) -> None: + replicate.state(self)._comm_hook_args.append((args, kwargs)) # type: ignore[arg-type] + + +@contract(state_cls=_ReplicateState) +def replicate( + module: nn.Module, + ignored_modules: Iterable[torch.nn.Module] | None = None, + **kwargs, +) -> nn.Module: + r"""Replicates a module + + Args: + module (torch.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + torch._C._log_api_usage_once("torch.distributed.replicate") + + # TODO(fegin): using kwargs is not a good idea if we would like to make + # replicate a formal API to replace DDP. + if "device_id" in kwargs: + if not isinstance(kwargs["device_id"], (int, torch.device)): + raise RuntimeError( + "Expected device_id to be int or torch.device, " + f"but got {type(kwargs['device_id'])}" + ) + + if _is_fully_sharded(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if ignored_modules is None: + ignored_modules = {} + else: + ignored_modules = set(ignored_modules) + + state = replicate.state(module) + module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) + device_mesh = kwargs.get("device_mesh") + if device_mesh is not None: + root_mesh = device_mesh._get_root_mesh() + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # + # This won't conflict with what is done in DDP class as the module + # replicate is going to pass is NOT the original module. + from torch.distributed.tensor.parallel.ddp import ( + _localize_dtensor, + _reconstruct_dtensor, + ) + + module.register_forward_pre_hook(_reconstruct_dtensor) + module.register_forward_hook(_localize_dtensor) + + module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] + + state.record_init_args(module, ignored_modules, **kwargs) + + # Place DDP leftmost for highest priority in the method resolution order + cls = module.__class__ + dct = {"__deepcopy__": unimplemented_deepcopy} + new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) + module.__class__ = new_cls + return module + + +def _is_fully_sharded(module: nn.Module) -> bool: + r"""Check if module is marked with fully_shard.""" + registry = _get_registry(module) + if registry is None: + return False + return "fully_shard" in registry diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..6c242323bcac82a55198f6f768ff5bd60c01595f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable/replicate_with_fsdp.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import overload + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable_state import _get_module_state, _insert_module_state +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.fsdp._fully_shard._fsdp_api import ( + MixedPrecisionPolicy, + OffloadPolicy, +) +from torch.distributed.fsdp._fully_shard._fsdp_common import ( + DDPMeshInfo, + detect_compiled_autograd, +) +from torch.distributed.fsdp._fully_shard._fsdp_init import ( + _get_device_from_mesh, + _get_managed_states, + _init_default_fully_shard_mesh, + _move_states_to_device, +) +from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup +from torch.distributed.fsdp._fully_shard._fsdp_state import ( + _register_group_forward_hooks, + FSDPState, +) +from torch.distributed.fsdp._fully_shard._fully_shard import ( + _unimplemented_deepcopy, + FSDPModule, +) +from torch.distributed.tensor import DeviceMesh, init_device_mesh +from torch.distributed.utils import _get_root_modules + +from .contract import _get_registry, contract + + +cls_to_replicate_cls: dict[type, type] = {} + +_ROOT_MODULE_PREFIX = "" + +logger = logging.getLogger("torch.distributed._composable.replicate_with_fsdp") + + +class _ReplicateStateContext: + """This has state shared across Replicate states.""" + + def __init__(self) -> None: + # All Replicate states in the root state's module tree + self.all_states: list[_ReplicateState] = [] + # Iteration's forward root runs the once-per-forward logic; this root + # may not be the overall root set by lazy initialization in cases where + # only a submodule runs forward (e.g. encoder-only for eval) + self.iter_forward_root: _ReplicateState | None = None + # Final callback should only be queued once per backward + self.post_backward_final_callback_queued: bool = False + # Whether to finalize backward in this backward's final callback + self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: torch.Event | None = None + + +def _get_module_replicate_state(module: nn.Module) -> _ReplicateState | None: + """Checks if module state is ReplicateState""" + state = _get_module_state(module) + if isinstance(state, _ReplicateState): + return state + return None + + +class _ReplicateState(FSDPState): + """ + Replicate state functionality is adapted from FSDP state. + In the future, could experiment with inheriting from it instead. + """ + + def __init__(self) -> None: + super().__init__() + self._state_ctx = _ReplicateStateContext() # type: ignore[assignment] + + # Define a separate init since `__init__` is called in the contract + def init( + self, + modules: tuple[nn.Module, ...], + device: torch.device, + mp_policy: MixedPrecisionPolicy, + auto_reshard_after_forward: bool = False, + ) -> None: + for module in modules: + _insert_module_state(module, self) + self._modules = modules + # pyrefly: ignore [read-only] + self._device = device + self._device_handle = _get_device_handle(device.type) + self._mp_policy = mp_policy + self._auto_reshard_after_forward = auto_reshard_after_forward + if len(modules) == 1: + self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( + self._pre_forward, prepend=True, with_kwargs=True + ) + self._post_forward_hook_handle = modules[0].register_forward_hook( + self._post_forward, prepend=False + ) + else: + hook_handle = _register_group_forward_hooks( + modules, + self._pre_forward, + self._post_forward, + self._modules_to_run_forward, + ) + self._pre_forward_hook_handle = hook_handle + self._post_forward_hook_handle = hook_handle + + def _lazy_init(self) -> None: + """ + Lazy initialization represents when all modules' parallelisms have + finalized (e.g. Replicate has been applied to all desired modules). This + means that we can determine which state is the root, and we do so by + the 1st state to run forward. + """ + if self._is_root is not None: + return # no-op: already initialized + self._is_root = True + if len(self._modules) > 1: + raise RuntimeError( + f"Replicate requires a single root module but got {self._modules}" + ) + detect_compiled_autograd() + root_module = self._modules[0] + visited_states: set[_ReplicateState] = set() + for module_name, module in root_module.named_modules(): + if (state := _get_module_replicate_state(module)) is None: + continue + if module is not root_module: + if state not in visited_states and state._is_root is not None: + raise RuntimeError( + "Replicate state has already been lazily initialized for " + f"{module_name}\nReplicate requires running forward through " + "the root module first" + ) + state._is_root = False + self._state_ctx.all_states.append(state) + # pyrefly: ignore [bad-argument-type] + visited_states.add(state) + if self._fsdp_param_group and self._auto_reshard_after_forward: + # For the root, do not reshard after forward since for training, + # the parameters would be freed and all-gathered immediately + self._fsdp_param_group.post_forward_mesh_info = None + self._init_fqns() + self._init_shared_state() + # Run parameter group lazy inits after initializing FQNs for improved + # error messages + for state in self._state_ctx.all_states: # type: ignore[assignment] + if state._fsdp_param_group: # type: ignore[union-attr] + state._fsdp_param_group.lazy_init() # type: ignore[union-attr] + + +def replicate_impl( + module, + mesh: DeviceMesh, + *, + device_id: int | torch.device | None = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: set[nn.Parameter] | None = None, +): + torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + raise ValueError( + f"replicate does not support containers that do not implement forward: {module}" + ) + + mesh = mesh or _init_default_fully_shard_mesh() + if mesh.ndim != 1: + raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}") + + else: + if mesh.mesh_dim_names is None: + raise AssertionError( + "Please init the 2D mesh for HSDP with mesh_dim_names specified" + ) + mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) + device = _get_device_from_mesh(mesh) + + post_forward_mesh_info = None + + arg_module = module + modules = ( + (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) + ) + state = replicate.state(modules[0]) # type: ignore[attr-defined] # see [1] + state.init(modules, device, mp_policy) + + managed_modules = _get_managed_modules(modules, ignored_params) + params, buffers = _get_managed_states(managed_modules, ignored_params) + + _move_states_to_device(params, buffers, device) + if params: + state._fsdp_param_group = FSDPParamGroup( + params, + modules, + mesh_info, # type: ignore[arg-type] + post_forward_mesh_info, + device, + None, + mp_policy, + offload_policy, + ) + + # Place Replicate leftmost for highest priority in the method resolution order + for module in modules: + cls = module.__class__ + new_cls = cls_to_replicate_cls.get(cls) + if not new_cls: + dct = {"__deepcopy__": _unimplemented_deepcopy} + new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct) + cls_to_replicate_cls[cls] = new_cls + module.__class__ = new_cls + return arg_module + + +@overload +# pyrefly: ignore [inconsistent-overload] +def replicate( + module: nn.Module, + *, + mesh: DeviceMesh | None = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: set[nn.Parameter] | None = ..., +) -> ReplicateModule: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def replicate( + module: list[nn.Module], + *, + mesh: DeviceMesh | None = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: set[nn.Parameter] | None = ..., +) -> list[ReplicateModule]: ... + + +@contract(state_cls=_ReplicateState) # type: ignore[misc] +def replicate( + module: nn.Module, + *, + mesh: DeviceMesh | None = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: set[nn.Parameter] | None = None, +): + r"""Replicates a module + + Args: + module (torch.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + + if not is_composable_with_replicate(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if mesh is None: + mesh = replicate_mesh() + + return replicate_impl( + module, + mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + ignored_params=ignored_params, + ) + + +class ReplicateModule(FSDPModule): + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the FSDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class + # and index 1 is the `FSDPModule` class itself + orig_cls = cls.__mro__[3] + self = orig_cls.__new__(orig_cls, *args, **kwargs) + self.__init__(*args, **kwargs) + return self + + +def _get_managed_modules( + root_modules: tuple[nn.Module, ...], + ignored_params: set[nn.Parameter] | None = None, +) -> list[nn.Module]: + modules: list[nn.Module] = [] + root_modules_set = set(root_modules) + # Track visisted modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = set() + + def dfs(module: nn.Module) -> None: + """ + Runs a DFS to collect managed modules, not recursing into modules with + a non-composable API or ``replicate`` already applied. + """ + if not is_composable_with_replicate(module): + return + elif ( + module not in root_modules_set + and _get_module_replicate_state(module) is not None + ): + return # nested `fully_shard` module + visited_modules.add(module) + for submodule in module.children(): + if submodule not in visited_modules: + dfs(submodule) + modules.append(module) + + for root_module in root_modules: + dfs(root_module) + + if ignored_params is None: + return modules + + adjusted_modules = _adjust_managed_modules(modules, ignored_params) + return adjusted_modules + + +def is_composable_with_replicate(module: nn.Module) -> bool: + """Checks if replicate can be applied with module""" + registry = _get_registry(module) + if registry is None: + return True + # Registry keys by function name + return "fully_shard" not in registry + + +def replicate_mesh(): + """Creates a device mesh for replicate if the user doesn't provide one""" + if not dist.distributed_c10d.is_initialized(): + dist.distributed_c10d.init_process_group() + default_pg = dist.distributed_c10d._get_default_group() + device = torch._C._get_accelerator() + mesh = init_device_mesh( + device.type, + mesh_shape=(default_pg.size(),), + mesh_dim_names=("replicate",), + ) + return mesh + + +def _adjust_managed_modules( + modules: list[nn.Module], ignored_params: set[nn.Parameter] +) -> list[nn.Module]: + """ + Adjust the given list of managed modules by removing those with all parameters ignored. + """ + ignore_decision: dict[nn.Module, bool] = {} + new_modules = [] + for module in modules: + ignored = _ignore_module(module, ignored_params, ignore_decision) + if not ignored: + new_modules.append(module) + return new_modules + + +def _ignore_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignore_decision: dict[nn.Module, bool], +) -> bool: + """ + Decide if it is safe to ignore a module for applying replicate. + """ + if module in ignore_decision: + return ignore_decision[module] + + if len(list(module.buffers(recurse=False))) > 0: + # Cannot ignore a module with any buffer + ignore_decision[module] = False + return False + + for _, param in module.named_parameters(recurse=False): + if param not in ignored_params: + # at least one param is not ignored. So this module shouldn't be. + ignore_decision[module] = False + return False + + # Need to consider descendants of module + for child in list(module.children()): + ignore_child = _ignore_module(child, ignored_params, ignore_decision) + if not ignore_child: + # Cannot ignore module if one of its children is not ignored + ignore_decision[module] = False + return False + + # Safe to ignore module + ignore_decision[module] = True + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..632f28224193de697b7eb96608a262f58dd6363d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/__init__.py @@ -0,0 +1,1965 @@ +from ast import Call + +from torch._ops import OpOverload + + +""" +A LocalTensor is a tensor subclass which simulates a tensor that is +distributed across SPMD ranks. A LocalTensor might be size N, but in fact +there are world_size shards/replicas of it stored internally. When you do a +plain PyTorch operation on it, we apply the operation to each shard; when you +do a collective, we do the mathematically equivalent operation on the local +shards. A LocalTensor is associated with a list of ranks which specify +which ranks it holds local tensors for. + +NB, this is NOT a DataParallel like abstraction where you can run operations +on multiple different GPUs. It is intended purely for *debugging* purposes, +the overhead is almost certainly too high to keep eight GPUs (even the C++ +autograd needs multithreading to keep up!) (It might potentially be possible +to trace through this with torch.compile and then compile it with CUDA graphs +but this is currently a non-goal.) + +We do not directly handling MPMD. However in practice even in SPMD you may +encounter divergence in behavior per rank (for example, uneven sharding +across ranks). To support scenarios like this, we provide a helper decorator +that allows you to run a function with no side effects for each LocalTensor +shard and combine results back into LocalTensor or LocalIntNode. + +NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd +is SPMD, so we run it once, and dispatch the inner autograd calls to the individual +local shards. + +NOTE ABOUT MESH: This subclass requires collectives that are issued to it to +respect a DeviceMesh like abstraction. The reason for this is that when +DTensor issues us a collective for a particular rank, you will be asked to do +this on a specific process group which involves some ranks. However, this +will only be for the LOCAL PG that this particular rank is participating in; +there will be a bunch of other PGs for other nodes that you don't get to see. +We need to be able to reverse engineer all of the collectives that don't +involve the current local rank here to actually issue them. This can be done +two ways: (1) looking at the participating local ranks in the PG and computing +the complement which specifies all the other collectives you have to run, or +(2) retrieving the device mesh axis corresponding to the PG for this rank, and +then running all the fibers for this. +""" + +import contextlib +import copy +import functools +import operator +import os +import sys +import threading +from collections import defaultdict +from collections.abc import Callable, Generator, Sequence +from types import TracebackType +from typing import Any, Optional, ParamSpec, TypeVar, Union + + +try: + import numpy as np + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + np = None # type: ignore[assignment] + +import torch +import torch.distributed as dist +from torch import Size, SymBool, SymInt, Tensor +from torch._C import DispatchKey, DispatchKeySet, ScriptObject +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.distributed import DeviceMesh, ProcessGroup +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.distributed_c10d import _get_default_group +from torch.fx.experimental._constant_symnode import ConstantIntNode +from torch.nested._internal.nested_int import NestedIntNode +from torch.utils import _pytree as pytree +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + return_and_correct_aliasing, + TorchDispatchMode, +) +from torch.utils.checkpoint import get_device_states, set_device_states + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + + +from . import _c10d + + +def _is_in_fake_tensor_mode() -> bool: + return any( + isinstance(mode, FakeTensorMode) for mode in _get_current_dispatch_mode_stack() + ) + + +def _reduce_multidim_lists( + lists_to_reduce: list[Any], reduce_func: Callable[[list[Any]], Any] +) -> Any: + """ + Reduces a list of multi-dimensional lists, assuming they all have + the exact same shape. + + Args: + lists_to_reduce (list): A list where each item is a multi-dimensional + list (e.g., [md_list_1, md_list_2, ...]). + All inner md_lists must have the same shape. + reduce_func (callable): A function that takes an iterable (list) of + values and returns a single reduced value. + For example: sum, max, min, or + lambda x: sum(x) / len(x) for mean. + + Returns: + A single multi-dimensional list of the same shape as the inputs, + where each value is the result of the reduce_func. + + Raises: + ValueError: If the input list is empty or if shapes are inconsistent + (which may also raise IndexError or TypeError). + """ + if not lists_to_reduce: + raise ValueError("Input 'lists_to_reduce' cannot be empty.") + + # Get the first list to inspect its structure (shape) + first_list = lists_to_reduce[0] + + # Check if the first element of this list is *also* a list. + # This determines if we are at the base case or need to recurse. + if isinstance(first_list[0], list): + # --- RECURSIVE STEP --- + # The elements are lists, so we need to go one level deeper. + + # We find the number of sub-lists from the first list. + # (e.g., for [[1,2], [3,4]], this is 2) + num_sublists = len(first_list) + + result = [] + # Iterate by the index of the sub-lists (e.g., i = 0, then i = 1) + for i in range(num_sublists): + # Build a new list to pass to the recursive call. + # This list will contain the i-th sublist from *each* of the + # input lists. + # e.g., if lists_to_reduce = [ L1, L2 ] and i = 0, + # this creates [ L1[0], L2[0] ] + sublists_to_reduce = [l[i] for l in lists_to_reduce] + + # Recurse and append the result + result.append(_reduce_multidim_lists(sublists_to_reduce, reduce_func)) + return result + else: + # --- BASE CASE --- + # The elements are values (int, float, etc.), not lists. + # We are at the innermost dimension. + + # Find the number of values in the innermost list. + # (e.g., for [1, 2], this is 2) + num_values = len(first_list) + + result = [] + # Iterate by the index of the values (e.g., i = 0, then i = 1) + for i in range(num_values): + # Get the values at this specific position (i) from *all* + # input lists. + # e.g., if lists_to_reduce = [ [1,2], [10,20] ] and i = 0, + # this creates [ 1, 10 ] + values_at_pos = [l[i] for l in lists_to_reduce] + + # Apply the user-provided reduction function to this list of values + # and append the single result. + result.append(reduce_func(values_at_pos)) + return result + + +def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: + return ( + isinstance(op, OpOverload) + # Not precise heuristic to detect inplace operation + and op._schema.name[-1] == "_" + # Strengthen the heuristic to check that the first argument and return value are a write + and len(op._schema.arguments) > 0 + and op._schema.arguments[0].is_write + and len(op._schema.returns) > 0 + and op._schema.returns[0].is_write + ) + + +def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int: + if isinstance(i, LocalIntNode): + return i._local_ints[r] + elif isinstance(i, ConstantIntNode): + return i.val + elif isinstance(i, int): + return i + else: + raise AssertionError(type(i)) + + +def _check_for_subclass(flat_args: Sequence[object]) -> bool: + return any(_check_for_subclass_arg(x) for x in flat_args) + + +def _check_for_subclass_arg(x: object) -> bool: + return ( + not isinstance(x, LocalTensor) + and isinstance(x, Tensor) + and type(x) + not in ( + Tensor, + FakeTensor, + torch.nn.Parameter, + torch.nn.Buffer, + ) + ) + + +def _map_to_rank_local_val(val: Any, rank: int) -> Any: + if isinstance(val, LocalTensor): + return val._local_tensors[rank] + if isinstance(val, SymInt): + if isinstance(val.node, LocalIntNode): + return val.node._local_ints[rank] + if isinstance(val.node, ConstantIntNode): + return val.node.val + return val + + +def _collect_accelerator_rng_states() -> dict[int, torch.Tensor]: + """ + Collects RNG state from all available acceleator devices. + + Returns: + List of RNG state tensors, one for each accelerator device. + Returns empty list if accelerator is not available. + """ + if not torch.accelerator.is_available(): + return {} + + if torch.accelerator.is_available(): + device_idx = torch.accelerator.current_device_index() + with torch.accelerator.device_index(device_idx): + return {device_idx: torch.get_device_module().get_rng_state()} + + return {} + + +def _set_accelerator_rng_states(rng_states: dict[int, torch.Tensor]) -> None: + """ + Sets RNG state for all accelerator devices from a list of states. + + Args: + rng_states: List of RNG state tensors to restore. + """ + if not torch.accelerator.is_available(): + return + + if torch.accelerator.is_available(): + for device_idx, device_rng_state in rng_states.items(): + with torch.accelerator.device_index(device_idx): + torch.get_device_module().set_rng_state(device_rng_state) + + +def _get_rng_state() -> tuple[torch.Tensor, dict[int, torch.Tensor]]: + """ + Gets CPU and accelerator (e.g., CUDA, XPU device) rng states from all devices. + """ + return (torch.get_rng_state(), _collect_accelerator_rng_states()) + + +def _set_rng_state( + cpu_state: torch.Tensor, accelerator_states: dict[int, torch.Tensor] +) -> None: + """ + Sets CPU and accelerator (e.g., CUDA, XPU device) rng states for all devices. If + the list of accelerator states is shorter than the number of devices only the + first len(accelerator_states) devices will get their rng state set. + """ + torch.set_rng_state(cpu_state) + _set_accelerator_rng_states(accelerator_states) + + +def _combine_int_rank_results(rank_results: dict[int, int]) -> int | torch.SymInt: + any_v = next(iter(rank_results.values())) + + if all(v == any_v for v in rank_results.values()): + return any_v + + return torch.SymInt(LocalIntNode(rank_results)) + + +def _combine_any_rank_results(rank_results: dict[int, Any]) -> Any: + any_v = next(iter(rank_results.values())) + + if isinstance(any_v, Tensor): + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(rank_results) + + if isinstance(any_v, int): + return _combine_int_rank_results(rank_results) + + if isinstance(any_v, torch.device): + assert all(v.type == any_v.type for v in rank_results.values()), ( + "device type should be the same" + ) + # Just use the first device - the device type is what matters, + # and LocalTensorMode runs on a single physical device anyway + return any_v + + assert all(v == any_v for v in rank_results.values()), ( + "Non Tensor or int rank results must be equal for all ranks" + ) + + return any_v + + +def _combine_rank_results(rank_results: dict[int, Any], default: Any | None) -> Any: + rank_ids = rank_results.keys() + rank_value = rank_results[next(iter(rank_ids))] + + if isinstance(rank_value, (list, tuple)): + max_rank_result_len = max(len(v) for v in rank_results.values()) + ret_list = [] + for i in range(max_rank_result_len): + rank_col_results = { + r: v[i] if i < len(v) else default for r, v in rank_results.items() + } + ret_list.append(_combine_any_rank_results(rank_col_results)) + return type(rank_value)(ret_list) + else: + return _combine_any_rank_results(rank_results) + + +def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor: + tensor_size = list(tensor.size()) + tensor_size[dim] = 0 + empty_tensor = torch.empty(*tensor_size, dtype=tensor.dtype, device=tensor.device) + return empty_tensor + + +def _for_each_rank_run_func( + func: OpOverload | Callable[..., Any], + ranks: frozenset[int], + args: Sequence[Any], + kwargs: dict[str, Any], + *, + alias: bool = True, +) -> Any: + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + flat_args = [ + a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args + ] + + lm = enabled_local_tensor_mode() + use_per_rank_rng = lm is not None and len(lm._per_rank_rng_states) > 0 + + global_rng_state = None if use_per_rank_rng else _get_rng_state() + + flat_rank_rets = {} + + default_value: Tensor | None = None + for r in sorted(ranks): + if use_per_rank_rng: + assert lm is not None + if r in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[r]) + else: + assert global_rng_state is not None + _set_rng_state(*global_rng_state) + + rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] + rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) + if func is torch.ops.aten.hash_tensor.default and rank_args[0].numel() == 0: + # Special case for empty tensors, hash_tensor returns an empty tensor + rank_ret = torch.empty(0, dtype=torch.uint64, device=rank_args[0].device) + else: + rank_ret = func(*rank_args, **rank_kwargs) + flat_rank_rets[r] = rank_ret + + if use_per_rank_rng: + assert lm is not None + lm._per_rank_rng_states[r] = _get_rng_state() + + if default_value is None and func is torch.ops.aten.split.Tensor: + # If split happens over the dimension smaller than the number of chunks + # it is possible that some ranks will produce shorter lists of chunks. + # In order to make the result across all ranks of the same length we + # append empty tensors (zero size on the split dimension). + tensor = rank_flat_args[0] + split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2] + default_value = _zero_sized_like(tensor, split_dim) + + if _is_inplace_op(func): + alias = False + # For the in-place ops return self + ret = args[0] + if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags: + # Ensure that wrapper tensor size is synchronized with its local tensors + ret._sync_meta() + else: + ret = _combine_rank_results(flat_rank_rets, default_value) + + if alias: + return return_and_correct_aliasing(func, args, kwargs, ret) + else: + return ret + + +def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet: + extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0) + if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate): + extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate) + if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative): + extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative) + return extra_dispatch_keys + + +class LocalIntNode: + """ + Like a LocalTensor, but for an int. We can't use a 0D tensor to represent this + because often only a SymInt is accepted where we wish to use this. + """ + + def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode": # type: ignore[misc] + if len(set(local_ints.values())) == 1: + return ConstantIntNode(next(iter(local_ints.values()))) + return super().__new__(cls) + + def __init__(self, local_ints: dict[int, int]): + self._local_ints = local_ints + + def maybe_as_int(self) -> int | None: + return None + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return False + + def clone(self) -> "LocalIntNode": + return self + + def _str(self) -> str: + return f"LocalIntNode({self._local_ints})" + + def __str__(self) -> str: + return self._str() + + def __repr__(self) -> str: + return self._str() + + def _graph_repr(self) -> str: + return self._str() + + def is_symbolic(self) -> bool: + return False + + def is_constant(self) -> bool: + return False + + def sym_max( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + { + r: max(self._local_ints[r], _int_on_rank(other, r)) + for r in self._local_ints + } + ) + + def sym_sum(self, other: Any) -> "LocalIntNode | ConstantIntNode": + t = LocalIntNode(dict.fromkeys(self._local_ints, 0)) + for o in other: + t = t.add(o) + return t + + def neg(self) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode({r: -self._local_ints[r] for r in self._local_ints}) + + def add( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints} + ) + + def sub( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints} + ) + + def mul( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints} + ) + + def floordiv( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints} + ) + + def mod( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] % _int_on_rank(other, r) for r in self._local_ints} + ) + + def int_floordiv( + self, other: "int | LocalIntNode | ConstantIntNode" + ) -> "LocalIntNode | ConstantIntNode": + return LocalIntNode( + {r: self._local_ints[r] // _int_on_rank(other, r) for r in self._local_ints} + ) + + def eq(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints} + return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r))) + + def ne(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints} + return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r))) + + def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def lt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + + def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": + return ConstantIntNode(num) + + +class _LocalDeviceHandle: + """ + Wrapper around device module (e.g., torch.cuda) with automatic LocalTensor semantics. + + This class wraps device modules and automatically handles per-rank operations in + LocalTensor mode: + - get_rng_state() returns a LocalTensor with per-rank states + - set_rng_state(LocalTensor) sets per-rank states + + When not in LocalTensor mode, it delegates directly to the underlying device handle. + """ + + def __init__(self, device_handle, device_type: str): + """ + Initialize the local device handle wrapper. + + Args: + device_handle: The underlying device module (e.g., torch.cuda) + device_type: Device type string (e.g., "cuda", "cpu") + """ + self._device_handle = device_handle + self._device_type = device_type + + def get_rng_state(self): + """ + Get RNG state, automatically returning LocalTensor in LocalTensor mode. + + Returns: + LocalTensor in LocalTensor mode, regular Tensor otherwise + """ + lm = enabled_local_tensor_mode() + if not lm: + return self._device_handle.get_rng_state() + + original_state = _get_rng_state() + per_rank_states = {} + + try: + for rank in lm.ranks: + # We need to set-then-get instead of directly copying lm._per_rank_rng_states[rank] + # because they have different structures: + # - lm._per_rank_rng_states[rank] is a tuple: (cpu_state, {device_idx: cuda_state}) + # - self._device_handle.get_rng_state() returns just the device-specific tensor + # So we temporarily restore the full RNG state (CPU + all CUDA devices) for this rank, + # then extract only the specific device's state tensor that we need. + if rank in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[rank]) + + per_rank_states[rank] = self._device_handle.get_rng_state() + finally: + _set_rng_state(*original_state) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(per_rank_states) + + def set_rng_state(self, state): + """ + Set RNG state, automatically handling LocalTensor input. + + Args: + state: Regular Tensor or LocalTensor with per-rank states + """ + if isinstance(state, LocalTensor): + lm = enabled_local_tensor_mode() + assert lm is not None + + # Similar to get_rng_state but in reverse: we need to convert from + # device-specific tensor format to full state tuple format. + # - state._local_tensors[rank] contains just the device-specific RNG state tensor + # - lm._per_rank_rng_states[rank] needs a tuple: (cpu_state, {device_idx: cuda_state}) + # So we set the device's state with the rank-specific tensor, then _get_rng_state() + # captures both CPU and CUDA states into the tuple format that _per_rank_rng_states expects. + for rank, rank_state in state._local_tensors.items(): + self._device_handle.set_rng_state(rank_state.to("cpu")) + lm._per_rank_rng_states[rank] = _get_rng_state() + else: + self._device_handle.set_rng_state(state.to("cpu")) + + def __getattr__(self, name): + """Delegate all other attributes to the underlying device module.""" + return getattr(self._device_handle, name) + + +class _LocalOffsetBasedRNGTracker: + """ + LocalTensor-specific RNG tracker for DTensor random operations. + + This class manages per-rank RNG states when running in LocalTensor mode, + using _LocalPhiloxState to track different offsets for each virtual rank. + It is instantiated and used by OffsetBasedRNGTracker when in LocalTensor mode. + + Much of this is derived from OffsetBasedRNGTracker: + https://github.com/pytorch/pytorch/blob/402c46503002f98ccfc023a733081fb0719223a1/torch/distributed/tensor/_random.py#L182 + """ + + def __init__(self, device_type: str = "cuda"): + """Initialize the LocalTensor RNG tracker.""" + from torch.distributed.device_mesh import _get_device_handle + + self._device_type = device_type + self._device_handle = _LocalDeviceHandle( + _get_device_handle(device_type), device_type + ) + self.distribute_region_enabled = True + self._device_mesh = None + + @property + def _device(self): + return torch.device(self._device_type, torch.cuda.current_device()) + + def _set_pre_op_offset(self, state, spec) -> None: + """Compute and set per-rank offsets before the random operation.""" + from torch.distributed.tensor._ops.utils import prod + from torch.distributed.tensor._utils import ( + _compute_local_shape_and_global_offset, + ) + from torch.distributed.tensor.placement_types import Shard + + lm = enabled_local_tensor_mode() + assert lm is not None + + state._per_rank_offsets = {} + + for rank in lm.ranks: + # compute this rank's coordinate in the mesh + mesh_coords = [] + for mesh_dim_idx in range(spec.mesh.ndim): + mesh_dim_size = spec.mesh.size(mesh_dim_idx) + # calculate rank's coordinate in this mesh dimension + num_chunks_after = 1 + for j in range(mesh_dim_idx + 1, spec.mesh.ndim): + num_chunks_after *= spec.mesh.size(j) + coord = (rank // num_chunks_after) % mesh_dim_size + mesh_coords.append(coord) + + # compute shard offset based on placements + from torch.distributed.tensor._random import ( + _calc_first_shard_size, + _calc_shard_info, + _calc_shard_linear_idx, + ) + + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coords, spec + ) + + # compute shard linear index + shard_linear_idx = _calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) + + # get current offset for this rank + current_offset = int( + state._per_rank_states[rank][8:].view(dtype=torch.int64).item() + ) + + local_shape = _calc_first_shard_size(spec) + # compute local size + local_size = prod(local_shape) + + # compute new offset (must be multiple of 4) + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state._per_rank_offsets[rank] = current_offset + offset_incr + + def _set_post_op_offset(self, state, spec, old_offset) -> None: + """Set per-rank offsets after the random operation.""" + from torch.distributed.tensor._ops.utils import prod + + lm = enabled_local_tensor_mode() + assert lm is not None + + dtensor_shape = spec.shape + numel = prod(dtensor_shape) + # offset must be multiple of 4 + numel = (numel + 3) // 4 * 4 + + if not hasattr(state, "_per_rank_offsets"): + state._per_rank_offsets = {} + + # handle LocalIntNode old_offset (different values per rank) + if isinstance(old_offset, SymInt) and isinstance(old_offset.node, LocalIntNode): + for rank in lm.ranks: + rank_old_offset = old_offset.node._local_ints[rank] + state._per_rank_offsets[rank] = rank_old_offset + numel + else: + # same old_offset for all ranks + old_offset_int = ( + int(old_offset) if isinstance(old_offset, SymInt) else old_offset + ) + for rank in lm.ranks: + state._per_rank_offsets[rank] = old_offset_int + numel + + @contextlib.contextmanager + def _distribute_region(self, spec, generator=None): + """Context manager for LocalTensor mode distribute region.""" + lm = enabled_local_tensor_mode() + assert lm is not None + + # get base state + if generator is not None: + base_state_tensor = generator.get_state() + per_rank_states = {rank: base_state_tensor.clone() for rank in lm.ranks} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + base_state_tensor = LocalTensor(per_rank_states) + else: + base_state_tensor = self._device_handle.get_rng_state() + + state = _LocalPhiloxState(base_state_tensor) + + if self.distribute_region_enabled: + # sync to rank 0's state if no explicit generator + if generator is None: + any_rank_state = lm._any_local_rng_state() + any_rank_cpu, any_rank_cuda = any_rank_state + + if self._device.type == "cuda": + assert self._device.index in any_rank_cuda + any_rank_device_state = any_rank_cuda[self._device.index] + else: + any_rank_device_state = any_rank_cpu + + from torch.distributed.tensor._random import _PhiloxState + + any_rank_philox = _PhiloxState(any_rank_device_state) + state.seed = any_rank_philox.seed + state.offset = any_rank_philox.offset + + old_offset = state.offset + self._set_pre_op_offset(state, spec) + state.apply_to_local_tensor_mode(self._device_handle) + + try: + yield + finally: + self._set_post_op_offset(state, spec, old_offset) + state.apply_to_local_tensor_mode(self._device_handle) + else: + yield + + # maybe reset generator to rank 0's state + if generator is not None: + rank_0_state = state._per_rank_states[0] + generator.set_state(rank_0_state) + + +_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" + + +def _is_local_tensor_attr(attr: str) -> bool: + return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX) + + +def _to_local_tensor_attr(rank: int) -> str: + return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}" + + +def _from_local_tensor_attr(attr: str) -> int: + if not _is_local_tensor_attr(attr): + raise AssertionError(f"Invalid local tensor attr {attr}") + return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :]) + + +def _all_elements_same(values: list[Any]) -> bool: + if not values: + return True + first_value = values[0] + return all(value == first_value for value in values) + + +def _compute_local_tensor_meta( + local_tensors: dict[int, torch.Tensor], +) -> tuple[ + list[torch.SymInt | int], + list[torch.SymInt | int], + torch.device, + torch.dtype, + torch.layout, + DispatchKeySet, +]: + """ + Computes the meta information for a LocalTensor from its local tensors. + """ + it = iter(local_tensors.values()) + first_local_tensor = next(it) + + first_shape = first_local_tensor.shape + first_stride = first_local_tensor.stride() + dtype = first_local_tensor.dtype + device = first_local_tensor.device + layout = first_local_tensor.layout + + extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor) + + # Assert that all tensors have the same dtype, layout and dispatch keys. Due + # to uneven sharding, it is possible that tensors will have different shapes. + for local_tensor in it: + assert dtype == local_tensor.dtype, ( + "Tensors representing LocalTensor shards must have the same dtype" + ) + assert layout == local_tensor.layout, ( + "Tensors representing LocalTensor shards must have the same layout" + ) + assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), ( + "Tensors representing LocalTensor shards must have the same set of extra dispatch keys" + ) + + # Compute shape/stride. We allow for non-SPMD'ness here + local_shapes: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + local_strides: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + for r, local_tensor in local_tensors.items(): + for d, size in enumerate(local_tensor.shape): + local_shapes[d][r] = size + local_strides[d][r] = local_tensor.stride(d) + shape = [ + ( + first_shape[d] + if _all_elements_same(list(local_shapes[d].values())) + else torch.SymInt(LocalIntNode(local_shapes[d])) + ) + for d in range(len(first_shape)) + ] + strides = [ + ( + first_stride[d] + if _all_elements_same(list(local_strides[d].values())) + else torch.SymInt(LocalIntNode(local_strides[d])) + ) + for d in range(len(first_shape)) + ] + return shape, strides, device, dtype, layout, extra_dispatch_keys + + +class LocalTensor(torch.Tensor): + """ + LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD + (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from + global rank ids to their corresponding local Tensor shards.Operations performed on a LocalTensor + are applied independently to each local shard, mimicking distributed computation. Collectives + and other distributed operations are handled by mapping them to the local shards as appropriate. + + Note: + This class is primarily intended for debugging and simulating distributed tensor computations + on a single process. + + """ + + # Map from global rank to the local tensor. + _local_tensors: dict[int, torch.Tensor] + # Precomputed for speed set of keys from the local tensor map. + _ranks: frozenset[int] + _size: list[torch.SymInt | int] + __slots__ = ["_local_tensors", "_ranks", "_size"] + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensors: dict[int, torch.Tensor], + requires_grad: bool = False, + ) -> "LocalTensor": + if any(t.requires_grad for t in local_tensors.values()): + raise AssertionError( + "Internal local_tensors require grad, but we will ignore those autograd graph. " + "Make a custom autograd function and make sure you detach the inner tensors." + ) + + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(local_tensors) + ) + + r = torch.Tensor._make_wrapper_subclass( + cls, + shape, + strides=strides, + dtype=dtype, + device=device, + layout=layout, + # In place ops potentially change local tensor sizes (e.g. resize_). While + # executing an in-place op the return value must be the same as "self" input + # otherwise we can introduce errors due to tensor identity changes. Hence we + # need to be able to update wrapper subclass sizes after in-place ops. This + # dispatch policy allows us to do that. + dispatch_sizes_strides_policy="sizes", + requires_grad=requires_grad, + _extra_dispatch_keys=extra_dispatch_keys, + ) + + local_tensors = { + r: v if not isinstance(v, AsyncCollectiveTensor) else v.wait() + for r, v in local_tensors.items() + } + r._local_tensors = local_tensors + r._ranks = frozenset(local_tensors.keys()) + r._size = shape + return r + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental # type: ignore[misc] + def __init__(self, *args: Any, **kwargs: Any): + super().__init__() + + def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": + local_tensors_copy = { + r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() + } + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(local_tensors_copy, self.requires_grad) + + def __repr__(self) -> str: # type: ignore[override] + parts = [] + for k, v in self._local_tensors.items(): + # pyrefly: ignore [bad-argument-type] + parts.append(f" {k}: {v}") + tensors_str = ",\n".join(parts) + return f"LocalTensor(\n{tensors_str}\n)" + + def __getattr__(self, name: str) -> Any: + if _is_local_tensor_attr(name): + rank = _from_local_tensor_attr(name) + if rank not in self._ranks: + raise AttributeError(f"Local tensor has no knowledge of rank {rank}") + return self._local_tensors[rank] + return object.__getattribute__(self, name) + + def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]: + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks] + return local_tensor_attrs, () + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Any], + flatten_spec: tuple[Any, ...], + outer_size: torch.Size, + outer_stride: tuple[int, ...], + ) -> "LocalTensor": + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) + local_tensors = { + _from_local_tensor_attr(a): t for a, t in inner_tensors.items() + } + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(local_tensors) + + @classmethod + @torch._disable_dynamo + def __torch_dispatch__( # type: ignore[override] + cls, + func: Any, + types: tuple[Any, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + + # This is horribly inefficient + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + local_tensor = None + for arg in flat_args: + if isinstance(arg, LocalTensor): + local_tensor = arg + break + + assert local_tensor is not None, ( + "At least one of the arguments must be a LocalTensor" + ) + + # Check for unrecognized tensor subclasses (but allow regular tensors and scalars) + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "LocalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + with LocalTensorMode(local_tensor._ranks): + return func(*args, **kwargs) + + def numpy(self, *, force: bool = False) -> Any: + if HAS_NUMPY: + return self.reconcile().numpy(force=force) + else: + raise RuntimeError("Numpy is not available") + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> torch.Tensor: + # pyrefly: ignore [bad-argument-type] + return LocalTensor( + # pyrefly: ignore [bad-argument-count] + { + r: t.contiguous(memory_format=memory_format) + for r, t in self._local_tensors.items() + } + ) + + def is_contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> bool: + return all( + t.is_contiguous(memory_format=memory_format) + for t in self._local_tensors.values() + ) + + def tolist(self) -> list[Any]: + """ + Try to reconcile, if successful convert to list, otherwise if dtype is integer, + convert to list of local integers. + """ + equal_obj = self._equal_local_tensors() + if isinstance(equal_obj, torch.Tensor): + return equal_obj.tolist() + if isinstance(equal_obj, torch.Size): + if not self.dtype.is_floating_point and not self.dtype.is_complex: + ranks = sorted(self._ranks) + local_lists = [self._local_tensors[r].tolist() for r in ranks] + return _reduce_multidim_lists( + local_lists, + lambda values: torch.SymInt( + LocalIntNode(dict(zip(ranks, values, strict=True))) + ), + ) + + raise RuntimeError("Cannot convert local tensor to list") + + def reconcile(self) -> torch.Tensor: + """ + Reconciles the LocalTensor into a single torch.Tensor by ensuring all local + shards are identical and returning a detached clone of one of them. + + Note: + This method is useful for extracting a representative tensor from a LocalTensor + when all shards are expected to be the same, such as after a collective operation + that synchronizes all ranks. + """ + + # Force all local tensor shards across ranks to be the same + equal_obj = self._equal_local_tensors() + assert isinstance(equal_obj, torch.Tensor), ( + "LocalTensor shards must be the same to reconcile" + ) + cl = equal_obj.clone().detach() + cl.requires_grad_(self.requires_grad) + return cl + + def _equal_local_tensors(self) -> torch.Tensor | torch.Size | None: + it = iter(self._local_tensors.values()) + t1 = next(it) + if all(t2.equal(t1) for t2 in it): + return t1 + if all(t2.shape == t1.shape for t2 in it): + return t1.shape + return None + + def _sync_meta(self) -> None: + with no_dispatch(): + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(self._local_tensors) + ) + self._size = shape + + +# If set to `True` the LocalTensorMode stack will be created for the whole process, +# otherwise it will be created for each thread. +_PROCESS_MODE: bool = True +_PROCESS_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +# When running under local runner each thread must create its own local tensor mode +# so that they do not interfere with each other. +_THREAD_LOCAL_TENSOR_MODE: threading.local = threading.local() + + +def get_local_tensor_mode_list() -> list["LocalTensorMode"]: + global _PROCESS_MODE + if _PROCESS_MODE: + global _PROCESS_LOCAL_TENSOR_MODE + return _PROCESS_LOCAL_TENSOR_MODE + global _THREAD_LOCAL_TENSOR_MODE + if not hasattr(_THREAD_LOCAL_TENSOR_MODE, "value"): + _THREAD_LOCAL_TENSOR_MODE.value = [] + return _THREAD_LOCAL_TENSOR_MODE.value + + +class LocalTensorMode(TorchDispatchMode): + """ + A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution + for LocalTensor objects across a set of ranks. + + LocalTensorMode enables PyTorch operations to be transparently applied to each + local shard of a LocalTensor, as if they were distributed across multiple ranks. + When active, this mode intercepts tensor operations and dispatches them to each + rank's local tensor, collecting and wrapping the results as LocalTensors. It also + handles collective operations by mapping them to local implementations. + + This mode is primarily intended for debugging and simulating distributed tensor + computations on a single process, rather than for high-performance distributed + training. It maintains a stack of active modes, patches DeviceMesh coordinate + resolution, and provides utilities for temporarily disabling the mode or mapping + functions over ranks. + """ + + # What ranks this local tensor mode is operating over + def __init__(self, ranks: int | frozenset[int]): + if isinstance(ranks, int): + # assume is world size + self.ranks = frozenset(range(ranks)) + else: + assert isinstance(ranks, frozenset) + self.ranks = ranks + self._disable = True + self._old_get_coordinate = None + self._old_get_rank = None + self._old_get_local_rank = None + self._old_torch_manual_seed: Any = None + self._old_torch_initial_seed: Any = None + self._per_rank_rng_states: dict[ + int, tuple[torch.Tensor, dict[int, torch.Tensor]] + ] = {} + + self.enable_() + + def __enter__(self) -> "LocalTensorMode": + self.enable_() + get_local_tensor_mode_list().append(self) + + # _distribute_region will compute correct per-shard offsets + # but we want all ranks to start with the same state + if not _is_in_fake_tensor_mode(): + cpu_state, cuda_states = _get_rng_state() + for rank in self.ranks: + self._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return super().__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.disable_() + get_local_tensor_mode_list().pop() + super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_dispatch__( + self, + func: Any, + types: tuple[Any, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + # Find all LocalTensor arguments to determine ranks + local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)] + + # Check for unrecognized tensor subclasses (but allow regular tensors and scalars) + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "LocalTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + # Factory functions convert into LocalTensor, so we don't have to + # transmute a Tensor into a LocalTensor if mutation happens... + # But if you do an operation on a Tensor, do NOT wrap it into a + # LocalTensor. This helps prevent accidents when you're doing Tensor + # operations on the inner non-wrapped tensors. + if not local_tensors: + if self._disable or any(isinstance(a, Tensor) for a in flat_args): + return func(*args, **kwargs) + + # For LocalTensors, verify they have compatible ranks + for a in flat_args: + if isinstance(a, LocalTensor): + assert a._ranks <= self.ranks, ( + f"Input LocalTensor {a} must be configured for a subset of the LocalTensorMode ranks {self.ranks}" + ) + + if func.overloadpacket == torch.ops.aten.dim: + return len(args[0]._size) + if func.overloadpacket == torch.ops.aten.sym_size: + return tuple(args[0]._size) + + if func.namespace == "c10d": + if func is torch.ops.c10d.allreduce_.default: + return _c10d._local_all_reduce_(*args, **kwargs) + elif func is torch.ops.c10d.allreduce_coalesced_.default: + return _c10d._local_allreduce_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default: + return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d.scatter_.default: + return _c10d._local_scatter_(*args, **kwargs) + elif func is torch.ops.c10d.broadcast_.default: + return _c10d._local_broadcast_(*args, **kwargs) + elif func is torch.ops.c10d.allgather_.default: + return _c10d._local_all_gather_(*args, **kwargs) + elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default: + return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d._allgather_base_.default: + return _c10d._local_allgather_base_(*args, **kwargs) + elif func is torch.ops.c10d._reduce_scatter_base_.default: + return _c10d._local_reduce_scatter_base_(*args, **kwargs) + elif func is torch.ops.c10d.gather_.default: + return _c10d._local_gather_(*args, **kwargs) + elif func is torch.ops.c10d.alltoall_.default: + return _c10d._local_alltoall_(*args, **kwargs) + elif func is torch.ops.c10d.alltoall_base_.default: + return _c10d._local_alltoall_base_(*args, **kwargs) + elif func is torch.ops.c10d.barrier.default: + return _c10d._local_barrier(*args, **kwargs) + elif func is torch.ops.c10d.monitored_barrier_.default: + return _c10d._local_monitored_barrier_(*args, **kwargs) + elif func is torch.ops.c10d.send.default: + return _c10d._local_send(*args, **kwargs) + elif func is torch.ops.c10d.recv_.default: + return _c10d._local_recv_(*args, **kwargs) + elif func is torch.ops.c10d.recv_any_source_.default: + return _c10d._local_recv_any_source_(*args, **kwargs) + raise NotImplementedError(f"{func} not implemented") + + if func.namespace == "_c10d_functional" or func.namespace == "_dtensor": + if func is torch.ops._dtensor.shard_dim_alltoall.default: + return _c10d._local_functional_shard_dim_alltoall(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_gather_into_tensor.default: + return _c10d._local_functional_all_gather_into_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.reduce_scatter_tensor.default: + return _c10d._local_functional_reduce_scatter_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_to_all_single.default: + return _c10d._local_functional_all_to_all_single(*args, **kwargs) + else: + with LocalTensorMode(self.ranks): + return func._op_dk( + DispatchKey.CompositeExplicitAutograd, *args, **kwargs + ) + + if func.namespace == "profiler": + return func(*args, **kwargs) + + if func.namespace == "_c10d_functional_autograd": + raise NotImplementedError(f"{func} not implemented") + + if func.namespace == "symm_mem": + raise NotImplementedError(f"{func} not implemented") + + return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True) + + def disable_(self): + if self._disable: + return + + self._unpatch_device_mesh() + self._unpatch_random_functions() + self._disable = True + + def enable_(self): + if not self._disable: + return + + self._patch_device_mesh() + self._patch_random_functions() + self._disable = False + + @contextlib.contextmanager + def disable(self) -> Generator[None, None, None]: + """ + Disables LocalTensorMode temporarily. Primarily is intended to be used to perform + rank specific computations and merge results back before enabling LocalTensorMode back. + """ + + # don't unpatch again if already disabled + if self._disable: + try: + yield + finally: + # re-disable if the yield messed + # with the state + self.disable_() + return + + self.disable_() + try: + yield + finally: + self.enable_() + + def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor({r: cb(r) for r in self.ranks}) + + def tensor_map( + self, tensor: LocalTensor, cb: Callable[[int, Tensor], Tensor | None] + ) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + results = {} + for r in self.ranks: + if r in tensor._local_tensors: + m = cb(r, tensor._local_tensors[r]) + if m is not None: + results[r] = m + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(results) + + def _any_local_rng_state(self) -> tuple[torch.Tensor, dict[int, torch.Tensor]]: + return self._per_rank_rng_states[next(iter(self.ranks))] + + def _patch_device_mesh(self) -> None: + assert self._old_get_coordinate is None + assert self._old_get_rank is None + assert self._old_get_local_rank is None + self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] + self._old_get_rank = DeviceMesh.get_rank # type: ignore[assignment] + self._old_get_local_rank = DeviceMesh.get_local_rank # type: ignore[assignment] + DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate # type: ignore[method-assign] + DeviceMesh.get_rank = _LocalDeviceMesh.get_rank # type: ignore[method-assign] + DeviceMesh.get_local_rank = _LocalDeviceMesh.get_local_rank # type: ignore[method-assign] + + def _unpatch_device_mesh(self) -> None: + assert self._old_get_coordinate is not None + assert self._old_get_rank is not None + assert self._old_get_local_rank is not None + DeviceMesh.get_coordinate = self._old_get_coordinate + DeviceMesh.get_rank = self._old_get_rank + DeviceMesh.get_local_rank = self._old_get_local_rank + # pyrefly: ignore [bad-assignment] + self._old_get_coordinate = None + # pyrefly: ignore [bad-assignment] + self._old_get_rank = None + # pyrefly: ignore [bad-assignment] + self._old_get_local_rank = None + + def _patch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is None: + self._old_torch_manual_seed = torch.random.manual_seed + torch.random.manual_seed = _LocalRandom.torch_manual_seed + torch.manual_seed = _LocalRandom.torch_manual_seed + + if self._old_torch_initial_seed is None: + self._old_torch_initial_seed = torch.random.initial_seed + torch.random.initial_seed = _LocalRandom.torch_initial_seed + torch.initial_seed = _LocalRandom.torch_initial_seed + + def _unpatch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is not None: + torch.random.manual_seed = self._old_torch_manual_seed + torch.manual_seed = self._old_torch_manual_seed + self._old_torch_manual_seed = None + + if self._old_torch_initial_seed is not None: + torch.random.initial_seed = self._old_torch_initial_seed + torch.initial_seed = self._old_torch_initial_seed + self._old_torch_initial_seed = None + + +class _LocalRandom: + """ + Holds implementations of random functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def torch_manual_seed(seed) -> torch._C.Generator: + """LocalTensor-aware version of torch.random.manual_seed.""" + if ( + (lm := enabled_local_tensor_mode()) + and isinstance(seed, torch.SymInt) + and isinstance(seed.node, LocalIntNode) + ): + from torch.random import _manual_seed_impl + + for rank in sorted(lm.ranks): + rank_seed = seed.node._local_ints[rank] + _manual_seed_impl(rank_seed) + lm._per_rank_rng_states[rank] = _get_rng_state() + return torch.random.default_generator + from torch.random import _manual_seed_impl + + result = _manual_seed_impl(seed) + + if lm is not None and len(lm._per_rank_rng_states) > 0: + cpu_state, cuda_states = _get_rng_state() + for rank in lm.ranks: + lm._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return result + + @staticmethod + def torch_initial_seed(): + """LocalTensor-aware version of torch.random.initial_seed.""" + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + return torch.random.default_generator.initial_seed() + rank_seeds = {} + + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = torch.random.default_generator.initial_seed() + + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + + return torch.random.default_generator.initial_seed() + + +# Save the original get_coordinate method before any patching + + +class _LocalDeviceMesh: + """ + Holds implementations of DeviceMesh functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def get_coordinate(self: DeviceMesh) -> list[int] | None: + # NB: In order to support submeshes the code below recreates for each + # rank submesh with the same mesh dimensions as current mesh. We are + # doing this because when submesh is created it is created for a particular + # rank (therefore below we are patching get_rank method). We are trying to + # limit the invasiveness of local tensor. + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + + coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] + for r in lm.ranks: + rank_tensor = self._layout.remap_to_tensor(self._rank_map) + rank_coords = (rank_tensor == r).nonzero().tolist() + assert len(rank_coords) == 1 + for d, c in enumerate(rank_coords[0][1:]): + coords[d][r] = c + + out = [torch.SymInt(LocalIntNode(c)) for c in coords] + # The output contains coordinates for each of the ranks with respect to + # their meshes formed from root mesh and selecting the same dimensions + # as the current mesh. + return out # type: ignore[return-value] + + @staticmethod + def get_rank(self) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + return torch.SymInt(LocalIntNode(local_ints={r: r for r in lm.ranks})) + + @staticmethod + def get_local_rank(self, mesh_dim: int | str | None = None) -> int | SymInt: + lm = enabled_local_tensor_mode() + assert lm is not None, "Unexpectedly not in LocalTensorMode" + + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + if isinstance(mesh_dim, str): + mesh_dim = self._mesh_dim_names.index(mesh_dim) + + # Compute local rank for each global rank + # get_coordinate returns a list of SymInt, one per mesh dimension + # We need to extract the coordinate for the specified mesh_dim + coords = _LocalDeviceMesh.get_coordinate(self) + assert coords is not None + return coords[mesh_dim] + + +def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any: + """ + Reconciles arguments by converting any LocalTensor instances in the input + arguments to their underlying torch.Tensor representation. + + This function is typically used to prepare arguments for functions that + expect standard torch.Tensor objects, by flattening the input arguments, + replacing LocalTensor instances with their reconciled (standard tensor) + versions, and then reconstructing the original argument structure. + + Args: + args: Positional arguments, possibly containing LocalTensor instances. + kwargs: Keyword arguments, possibly containing LocalTensor instances. + + Returns: + Any: The arguments with all LocalTensor instances replaced by their reconciled torch.Tensor equivalents, + preserving the original structure. + """ + if kwargs is None: + kwargs = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + reconciled_args = [ + a.reconcile() if isinstance(a, LocalTensor) else a for a in flat_args + ] + return pytree.tree_unflatten(reconciled_args, args_spec) + + +def local_tensor_mode() -> LocalTensorMode | None: + """ + Returns the current active LocalTensorMode if one exists. + + This function checks the global stack of LocalTensorMode instance. If there + is at least one LocalTensorMode active, it returns the most recently entered + (top of the stack) LocalTensorMode. If no LocalTensorMode is active, it returns None. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active, else None. + """ + local_tensor_mode_list = get_local_tensor_mode_list() + if len(local_tensor_mode_list) > 0: + return local_tensor_mode_list[-1] + return None + + +def enabled_local_tensor_mode() -> LocalTensorMode | None: + """ + Returns the current active LocalTensorMode only if it's enabled. + + This is a convenience function that combines the common pattern of checking + if local_tensor_mode() is not None and not disabled. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active and enabled, else None. + """ + lm = local_tensor_mode() + if lm is not None and not lm._disable: + return lm + return None + + +def maybe_run_for_local_tensor(func: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Decorator that ensures a function is executed for each local tensor shard + when running under LocalTensorMode. If not in LocalTensorMode, the function + is executed normally. When in LocalTensorMode, the function is run for each + rank, and the results are collected appropriately. + + This decorator is useful for functions that exhibit non-SPMD behavior, such + as those requiring rank specific actions. For example, a function that computes + offset into input tensor based on rank. + + Note that the function being decorated must not have any side effects and + contain operations for a single rank only. For example, wrapping a function + that performs a collective operation will not work. + + Args: + func (Callable[..., Any]): The function to be decorated. + + Returns: + Callable[..., Any]: The wrapped function that handles LocalTensorMode logic. + """ + + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if not (lm := enabled_local_tensor_mode()): + return func(*args, **kwargs) + ret = None + with lm.disable(): + ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False) + + return ret + + return wrapper + + +def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: + """ + Context manager that disables LocalTensorMode for the duration of the context. + """ + lm = local_tensor_mode() + return lm.disable() if lm is not None else contextlib.nullcontext() + + +def maybe_enable_local_tracker( + device_type: str, distribute_region_enabled: bool, spec, generator +): + """ + Returns a context manager for LocalTensor-mode RNG tracking if local tensor mode is enabled. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + distribute_region_enabled: Whether distribute region is enabled + spec: The DTensorSpec + generator: Optional torch.Generator + + Returns: + Context manager from local_tracker._distribute_region if local tensor mode is enabled, + otherwise None. + """ + if enabled_local_tensor_mode(): + local_tracker = _LocalOffsetBasedRNGTracker(device_type) + local_tracker.distribute_region_enabled = distribute_region_enabled + return local_tracker._distribute_region(spec, generator) + + return None + + +def get_generator_seed_for_device_type(device_type: str): + """ + Gets the generator seed for a specific device type, handling LocalTensor mode appropriately. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + + Returns: + If in LocalTensor mode with per-rank RNG states: + - Returns int if all ranks have the same seed + - Returns SymInt(LocalIntNode) if ranks have different seeds + Otherwise: + - Returns int seed from the device's RNG state + """ + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + device_module = torch.get_device_module(device_type) + + original_state = _get_rng_state() + + rank_seeds = {} + try: + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = int( + device_module.get_rng_state()[:8].view(torch.int64).item() + ) + finally: + # restore original state + _set_rng_state(*original_state) + + unique_seeds = set(rank_seeds.values()) + if len(unique_seeds) == 1: + return next(iter(unique_seeds)) + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + else: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + + +import threading +from queue import Queue + + +_LOCAL_RUNNER_MODE: "LocalRunnerMode | None" = None + + +class LocalRunnerMode: + """ + A class for running multiple SPMD functions concurrently, however at any point + in time only one function can be running. The main use case for the local runner + mode is to enable SPMD functions to be able to use send and recv to communicate + with each other. Without local runner mode send and recv are not supported. + """ + + runner_context = threading.local() + + def __init__( + self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None] + ): + if isinstance(ranks, int): + ranks = frozenset(range(ranks)) + self._ranks = ranks + self._fn = fn + self._run_lock = threading.Lock() + self._run_id = -1 + self._run_cond = threading.Condition(self._run_lock) + + self._recv_objects: dict[int, dict[int, Queue]] = { + dst: {src: Queue() for src in ranks} for dst in ranks + } + self._runners = [ + threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") + for i in range(concurrency) + ] + self._process_mode = True + + def __enter__(self) -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" + _LOCAL_RUNNER_MODE = self + + global _PROCESS_MODE + self._process_mode = _PROCESS_MODE + _PROCESS_MODE = False + for r in self._runners: + r.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + for r in self._runners: + r.join() + global _LOCAL_RUNNER_MODE + _LOCAL_RUNNER_MODE = None + + global _PROCESS_MODE + _PROCESS_MODE = self._process_mode + + def _run(self, id: int) -> None: + LocalRunnerMode.runner_context.id = id + # Only one thread can run at a time, hence must acquire the lock + try: + self._acquire_run_lock() + self._fn(id) + finally: + self._release_run_lock() + + def _acquire_run_lock(self) -> None: + self._run_lock.acquire() + self._run_id = LocalRunnerMode.runner_context.id + + def _release_run_lock(self) -> None: + self._run_id = -1 + self._run_lock.release() + + def _assert_holds_run_lock(self) -> None: + assert self._run_id == LocalRunnerMode.runner_context.id, ( + "Calling thread does not hold the run lock" + ) + + def _get_recv_object(self, src: int, dst: int) -> object | None: + peers = [src] if src != -1 else list(self._ranks) + recv_objects = self._recv_objects[dst] + + for p in peers: + if not recv_objects[p].empty(): + return recv_objects[p].get() + + return None + + def _signal_send(self, src: int, dst: int, obj: object) -> None: + assert obj is not None, "Cannot signal None" + # Only a single thread a time executes so it is safe to mutate + # read objects queue (executing thread is already holding the lock) + self._recv_objects[dst][src].put(obj) + # Signal directly condition variable since the calling thread is already + # holding the lock + self._run_cond.notify_all() + + def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: + # Wait for the object to be available + while True: + obj = self._get_recv_object(src, dst) + if obj is not None: + post(obj) + # Note that we are not releasing the lock here, since the thread + # will continue to run and therefore must hold the lock + return + self._run_cond.wait() + + @staticmethod + def current() -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" + return _LOCAL_RUNNER_MODE + + +class _LocalPhiloxState: + """ + LocalTensor-aware version of _PhiloxState that manages per-rank RNG states. + This class handles the case where the generator state is a LocalTensor, allowing + different offsets and seeds for different virtual ranks. + + Note: This is designed to be used as a drop-in replacement for _PhiloxState + when working with LocalTensors in the DTensor random ops implementation. + """ + + def __init__(self, state: torch.Tensor): + assert isinstance(state, LocalTensor), ( + "_LocalPhiloxState requires a LocalTensor" + ) + self._local_tensor = state + self._per_rank_states = { + rank: local_state.to("cpu") + for rank, local_state in state._local_tensors.items() + } + + @property + def state(self): + return LocalTensor(self._per_rank_states) # type: ignore[name-defined] + + @property + def offset(self) -> int | SymInt: + from torch.distributed.tensor._random import _PhiloxState + + offsets = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + offsets[rank] = rank_philox.offset + + if len(set(offsets.values())) == 1: + return next(iter(offsets.values())) + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return SymInt(LocalIntNode(offsets)) + + @offset.setter + def offset(self, offset: int | SymInt) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(offset, SymInt) and isinstance(offset.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_offset = offset.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.offset = rank_offset + else: + offset_int = int(offset) if isinstance(offset, SymInt) else offset + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.offset = offset_int + + @property + def seed(self) -> int | SymInt: + from torch.distributed.tensor._random import _PhiloxState + + seeds = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + seeds[rank] = rank_philox.seed + + if len(set(seeds.values())) == 1: + return next(iter(seeds.values())) + return SymInt(LocalIntNode(seeds)) + + @seed.setter + def seed(self, seed: int | SymInt) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(seed, SymInt) and isinstance(seed.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_seed = seed.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.seed = rank_seed + else: + seed_int = int(seed) if isinstance(seed, SymInt) else seed + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.seed = seed_int + + def apply_to_local_tensor_mode(self, device_handle) -> None: + """ + Apply per-rank RNG states to the LocalTensorMode's tracked states. + This updates both the device RNG state and the LocalTensorMode's _per_rank_rng_states. + + Args: + device_handle: The device handle to use for setting RNG state (_LocalDeviceHandle) + """ + if not enabled_local_tensor_mode(): + return + + assert hasattr(self, "_per_rank_offsets") + + for rank in sorted(self._per_rank_states.keys()): + offset_value = self._per_rank_offsets[rank] + if isinstance(offset_value, SymInt): + if isinstance(offset_value.node, LocalIntNode): + offset_value = offset_value.node._local_ints[rank] + else: + offset_value = int(offset_value) + + offset_tensor = torch.tensor( + [offset_value], dtype=torch.uint64, device="cpu" + ).view(torch.uint8) + self._per_rank_states[rank][8:] = offset_tensor + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + device_handle.set_rng_state(LocalTensor(self._per_rank_states)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py new file mode 100644 index 0000000000000000000000000000000000000000..b3eca57402c56d8b5e9cdb216245ee2652f5250e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_local_tensor/_c10d.py @@ -0,0 +1,1060 @@ +import functools +import math +import operator +from collections.abc import Callable, Sequence +from datetime import timedelta + +import torch +from torch._C import ScriptObject +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork +from torch.distributed._mesh_layout import _MeshLayout +from torch.distributed.distributed_c10d import ( + _check_op, + _get_default_group, + _resolve_process_group, + GroupName, + ProcessGroup, + ReduceOp, + Work, +) + + +# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][]) +# when you would expect Tensor (or Tensor[]). In fact, there will only ever +# be one Tensor in this case; the old signature was to support dispatching a +# collective on multiple devices (ala DataParallel) but we don't support that +# API anymore. Note that we are not 100% consistent about this; some more +# modern collectives like _allgather_base_ got rid of the unnecessary list. +# When in doubt, consult the code that dispatches to the collective on the PG +# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor], +# opts) indicates its always a list. + + +def _gcd_list(numbers: Sequence[int]) -> int: + return 0 if not numbers else functools.reduce(math.gcd, numbers) + + +def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]: + # Base case: A single index represents a point, not a dimension. + if len(indices) <= 1: + return (), () + + # The smallest stride is likely the GCD of the differences between consecutive indices. + # For a sorted, unique list, all differences will be positive. + diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))] + last_stride = _gcd_list(diffs) + + assert last_stride != 0, ( + # This case should not be reached if indices are unique and sorted. + "Cannot determine stride; indices may not be unique." + ) + + # Identify the starting index of each "row" in the last dimension. + # An index starts a new row if the preceding index (index - stride) is not present. + indices_set = set(indices) + higher_dim_indices = [indices[0]] + for index in indices[1:]: + if (index - last_stride) not in indices_set: + higher_dim_indices.append(index) + + # From the number of rows, we can deduce the shape of the last dimension. + assert len(indices) % len(higher_dim_indices) == 0, ( + "Indices do not form a regular grid. " + f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements." + ) + last_shape = len(indices) // len(higher_dim_indices) + + # Recurse on the higher-dimensional indices (the start of each row). + higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices) + + # Combine the results from the recursion with the current dimension's results. + final_shapes = higher_shapes + (last_shape,) + final_strides = higher_strides + (last_stride,) + + return final_shapes, final_strides + + +def _prepare_collective_groups( + process_group_so: ScriptObject | ProcessGroup, +) -> tuple[list[int], list[int], int]: + process_group = ( + ProcessGroup.unbox(process_group_so) + if isinstance(process_group_so, ScriptObject) + else process_group_so + ) + + ranks = torch.distributed.get_process_group_ranks(process_group) + assert ranks + # TODO: We can handle permutations but the layout inference algorithm will + # lose the permutation so we will have to reapply it + assert ranks == sorted(ranks), ranks + offset = ranks[0] + ranks = [r - offset for r in ranks] + + shape, strides = _indices_to_layout(ranks) + layout = _MeshLayout(shape, strides) + + global_pg = _get_default_group() + group_offsets = layout.complement(global_pg.size()).all_ranks_from_zero() + + return ranks, group_offsets, offset + + +# NB: There are two flavors of the collectives: regular and functional. Regular collectives +# allocate outputs to write the result to, accept process group and support async ops (return +# work object). Functional collectives expect the implementation to allocate outputs, accept +# process group name that must be resolved and do not support async ops (return output). +def _local_functional_all_gather_into_tensor( + tensor: torch.Tensor, group_size: int, group_name: GroupName +) -> torch.Tensor: + # "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor" + from . import LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + gathered_tensor = torch.cat(group_tensors, dim=0) + + for rank in group_ranks: + output_local_tensors[rank] = gathered_tensor.clone() + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_reduce_scatter_tensor( + tensor: torch.Tensor, reduce_op: str, group_size: int, group_name: GroupName +) -> torch.Tensor: + # "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor" + from . import _zero_sized_like, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank in enumerate(group_ranks): + if i < len(scattered_tensor): + output_local_tensors[rank] = scattered_tensor[i].clone() + else: + output_local_tensors[rank] = _zero_sized_like(reduced_tensor, 0) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_shard_dim_alltoall( + tensor: torch.Tensor, gather_dim: int, shard_dim: int, group_name: GroupName +) -> torch.Tensor: + # "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor" + from . import _zero_sized_like, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + group_tensors = [] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + gathered_tensor = torch.cat(group_tensors, dim=gather_dim) + + split_tensor = torch.split( + gathered_tensor, + gathered_tensor.size(shard_dim) // len(group_ranks), + dim=shard_dim, + ) + + for i, rank in enumerate(group_ranks): + if i < len(split_tensor): + output_local_tensors[rank] = split_tensor[i].clone() + else: + output_local_tensors[rank] = _zero_sized_like( + gathered_tensor, shard_dim + ) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_all_to_all_single( + tensor: torch.Tensor, + output_split_sizes: list[torch.SymInt], + input_split_sizes: list[torch.SymInt], + group_name: GroupName, +) -> torch.Tensor: + # "all_to_all_single(Tensor input, SymInt[] output_split_sizes, SymInt[] input_split_sizes, str group_name) -> Tensor" + from . import LocalIntNode, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + split_local_sizes: dict[int, list[int]] = {} + for input_split_size in input_split_sizes: + if isinstance(input_split_size, torch.SymInt) and isinstance( + input_split_size.node, LocalIntNode + ): + local_ints = dict(input_split_size.node._local_ints.items()) + else: + local_ints = {rank: int(input_split_size) for rank in tensor._local_tensors} + for rank, split_size in local_ints.items(): + if rank not in split_local_sizes: + split_local_sizes[rank] = [] + split_local_sizes[rank].append(split_size) + + split_local_tensors: dict[int, list[torch.Tensor]] = {} + + for rank, split_sizes in split_local_sizes.items(): + split_local_tensors[rank] = list( + torch.split(tensor._local_tensors[rank], split_sizes) + ) + + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in split_local_tensors for rank in group_ranks): + continue + + for i, dst in enumerate(group_ranks): + splits = [] + for j, src in enumerate(group_ranks): + splits.append(split_local_tensors[src][i]) + output_local_tensors[dst] = torch.cat(splits) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_broadcast_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + root_rank: int, + root_tensor: int, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)" + from . import LocalTensor + + assert len(tensors) == 1 + assert root_tensor == 0 + tensor = tensors[0] + + ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) + + # We're going to assume SPMD where for every rank group the root_rank is + # the same relative to others + relative_root_rank = root_rank - offset + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the broadcast on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + source_rank = group_offset + relative_root_rank + source_tensor = tensor._local_tensors[source_rank] + + # Broadcast the source tensor to all ranks in this group + for rank in group_ranks: + if source_rank != rank: + tensor._local_tensors[rank].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (tensors, work_so) + + +def _local_reduce( + reduce_op: ReduceOp | str, + tensors: list[torch.Tensor], +) -> torch.Tensor: + if reduce_op == ReduceOp.SUM or reduce_op == "sum": + op = operator.add + elif reduce_op == ReduceOp.AVG or reduce_op == "avg": + op = None + elif reduce_op == ReduceOp.PRODUCT or reduce_op == "product": + op = operator.mul + elif reduce_op == ReduceOp.MIN or reduce_op == "min": + op = torch.minimum + elif reduce_op == ReduceOp.MAX or reduce_op == "max": + op = torch.maximum + elif reduce_op == ReduceOp.BAND or reduce_op == "band": + op = torch.bitwise_and + elif reduce_op == ReduceOp.BOR or reduce_op == "bor": + op = torch.bitwise_or + elif reduce_op == ReduceOp.BXOR or reduce_op == "bxor": + op = torch.bitwise_xor + elif reduce_op == ReduceOp.PREMUL_SUM or reduce_op == "premul_sum": + raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor") + else: + raise NotImplementedError(f"ReduceOp {reduce_op} not implemented") + + if reduce_op == ReduceOp.AVG or reduce_op == "avg": + return functools.reduce(operator.add, tensors) / len(tensors) + else: + assert op is not None + return functools.reduce(op, tensors) + + +def _local_all_reduce_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + sparse_indices: torch.Tensor | None = None, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, " + # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + + # Collect tensors from the specified ranks in this group + group_tensors = [] + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + # Update all tensors in the group with the reduced result + for rank in group_ranks: + tensor._local_tensors[rank].copy_(reduced_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (tensors, work_so) + + +def _local_allreduce_coalesced_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work" + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on all tensors together + group_ranks = [group_offset + r for r in ranks] + + # For each tensor, perform the reduction operation + for tensor in tensors: + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + if not all(rank in tensor._local_tensors for rank in group_ranks): + continue + # Collect tensors from the specified ranks in this group + group_tensors = [] + for rank in group_ranks: + group_tensors.append(tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_tensor = _local_reduce(reduce_op, group_tensors) + + # Update all tensors in the group with the reduced result + for rank in group_ranks: + tensor._local_tensors[rank].copy_(reduced_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_reduce_scatter_tensor_coalesced_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, " + # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, " + # "int timeout=-1) -> __torch__.torch.classes.c10d.Work" + + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allreduce on all tensors together + group_ranks = [group_offset + r for r in ranks] + + # For each tensor, perform the reduction operation + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Collect tensors from the specified ranks in this group + group_inputs = [] + for rank in group_ranks: + group_inputs.append(input_tensor._local_tensors[rank]) + + # Perform the reduction operation + reduced_input = _local_reduce(reduce_op, group_inputs) + + reduced_input_splits = torch.split( + reduced_input, reduced_input.size(0) // len(group_ranks), dim=0 + ) + + # Update all tensors in the group with the reduced result + for i, rank in enumerate(group_ranks): + output_tensor._local_tensors[rank].copy_(reduced_input_splits[i]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_allgather_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup + # process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + gathered_tensor = torch.cat(gathered_tensors, dim=0) + + for rank_i in group_ranks: + output_tensor._local_tensors[rank_i].copy_(gathered_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_reduce_scatter_base_( # type: ignore[no-untyped-def] + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, + # __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, + # bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)" + + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + reduced_tensor = _local_reduce(reduce_op, gathered_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank_i in enumerate(group_ranks): + output_tensor._local_tensors[rank_i].copy_(scattered_tensor[i].clone()) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_all_gather_( + output_tensors: list[list[torch.Tensor]], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[list[torch.Tensor]], ScriptObject]: + # "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, " + # "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); + + from . import LocalTensor + + assert len(output_tensors) == 1 + assert len(input_tensors) == 1 + + input_tensor = input_tensors[0] + # pyrefly: ignore [bad-assignment] + output_tensors = output_tensors[0] + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + for i in range(len(output_tensors)): + assert isinstance(output_tensors[i], LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the all_gather on them + group_ranks = [group_offset + r for r in ranks] + + # For each rank in the group, gather from their input tensor + for i, rank_i in enumerate(group_ranks): + # allgather object happens to create pure tensor, so we special case it here + source_tensor = input_tensor + if isinstance(input_tensor, LocalTensor): + source_tensor = input_tensor._local_tensors[rank_i] + # pyrefly: ignore [missing-attribute] + output_tensors[i].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + # pyrefly: ignore [bad-return] + return ([output_tensors], work_so) + + +def _local_allgather_into_tensor_coalesced_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, +) -> ScriptObject: + # "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) " + # "-> __torch__.torch.classes.c10d.Work" + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + # Each output tensor should be sized to hold all gathered inputs + # outputs[i] will contain all inputs[i] from all ranks + assert len(output_tensors) == len(input_tensors), ( + f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the allgather_into_tensor on them + group_ranks = [group_offset + r for r in ranks] + + # For each input/output pair + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Gather input_tensor from all ranks into output_tensor + # The output should be a concatenation of all inputs along the first dimension + gathered_tensors = [] + for rank in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank]) + + # Concatenate along first dimension and copy to output + if gathered_tensors: + concatenated = torch.cat(gathered_tensors, dim=0) + for rank in group_ranks: + output_tensor._local_tensors[rank].copy_(concatenated) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_gather_( + output_tensors: list[list[torch.Tensor]], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + root_rank: int, + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, " + # "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work" + raise NotImplementedError( + "LocalTensor does not support MPMD operations like gather " + "(only root rank receives data). Use SPMD collective operations like allgather instead." + ) + + +def _local_scatter_( + output_tensors: list[torch.Tensor], + input_tensors: list[list[torch.Tensor]], + process_group_so: ScriptObject, + root_rank: int, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, " + # "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + + from . import LocalTensor + + assert len(output_tensors) == 1 + assert len(input_tensors) == 1 + output_tensor = output_tensors[0] + # pyrefly: ignore [bad-assignment] + input_tensors = input_tensors[0] + + ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) + + # We're going to assume SPMD where for every rank group the root_rank is + # the same relative to others + relative_root_rank = root_rank - offset + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert len(ranks) == len(input_tensors), (ranks, input_tensors) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the scatter on them + group_ranks = [group_offset + r for r in ranks] + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + # Root rank scatters its input tensors to all ranks in this group + for i, rank in enumerate(group_ranks): + input_tensor = input_tensors[i] + assert isinstance(input_tensor, LocalTensor) + # Each rank i gets the i-th input tensor from the root + source_tensor = input_tensor._local_tensors[ + group_offset + relative_root_rank + ] + output_tensor._local_tensors[rank].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (output_tensors, work_so) + + +def _local_alltoall_( + output_tensors: list[torch.Tensor], + input_tensors: list[torch.Tensor], + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[list[torch.Tensor], ScriptObject]: + # "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, " + # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, " + # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"; + + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert len(input_tensors) == len(output_tensors) == len(ranks), ( + f"Number of input tensors ({len(input_tensors)}), " + f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match" + ) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the alltoall on them + group_ranks = [group_offset + r for r in ranks] + + # In alltoall, rank i sends input_tensors[j] to rank j and receives into output_tensors[i] from rank j + for i, rank_i in enumerate(group_ranks): + output_tensor = output_tensors[i] + assert isinstance(output_tensor, LocalTensor), ( + "Output tensor must be a LocalTensor" + ) + + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + for j, rank_j in enumerate(group_ranks): + input_tensor = input_tensors[j] + assert isinstance(input_tensor, LocalTensor), ( + "Input tensor must be a LocalTensor" + ) + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + + # Rank i's j-th input tensor goes to rank j's i-th output tensor + source_tensor = input_tensor._local_tensors[rank_i] + output_tensor._local_tensors[rank_j].copy_(source_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return (output_tensors, work_so) + + +def _local_alltoall_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + output_split_sizes: list[int], + input_split_sizes: list[int], + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + # Convert split sizes to lists if they aren't already + if output_split_sizes is not None: + output_split_sizes = list(output_split_sizes) + if input_split_sizes is not None: + input_split_sizes = list(input_split_sizes) + + for group_offset in group_offsets: + # For the tensors in this group [group_offset + r for r in ranks] + # perform the alltoall_base on them + group_ranks = [group_offset + r for r in ranks] + + if not all(rank in input_tensor._local_tensors for rank in group_ranks): + continue + if not all(rank in output_tensor._local_tensors for rank in group_ranks): + continue + + for i, rank_i in enumerate(group_ranks): + # Split input tensor from rank_i according to input_split_sizes + rank_tensor = input_tensor._local_tensors[rank_i] + + if input_split_sizes is not None and len(input_split_sizes) > 0: + # Split the input tensor + input_splits = torch.split(rank_tensor, input_split_sizes, dim=0) + else: + # No split sizes specified, split evenly + split_size = rank_tensor.size(0) // len(group_ranks) + input_splits = torch.split(rank_tensor, split_size, dim=0) + + # Send each split to the corresponding rank + for j, rank_j in enumerate(group_ranks): + if j < len(input_splits): + split_tensor = input_splits[j] + + # Determine where to place this split in the output tensor + if output_split_sizes is not None and len(output_split_sizes) > 0: + # Calculate offset based on output split sizes + output_offset = sum(output_split_sizes[:i]) if i > 0 else 0 + end_offset = ( + output_offset + output_split_sizes[i] + if i < len(output_split_sizes) + else output_tensor._local_tensors[rank_j].size(0) + ) + else: + # No output split sizes, use even splits + split_size = output_tensor._local_tensors[rank_j].size( + 0 + ) // len(group_ranks) + output_offset = i * split_size + end_offset = min( + (i + 1) * split_size, + output_tensor._local_tensors[rank_j].size(0), + ) + + # Copy the split to the appropriate section of the output tensor + output_section = output_tensor._local_tensors[rank_j][ + output_offset:end_offset + ] + if output_section.numel() > 0: + # Reshape split_tensor to match output_section if necessary + if split_tensor.size() != output_section.size(): + split_tensor = split_tensor.view(output_section.size()) + output_section.copy_(split_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_barrier( + tensor: torch.Tensor, + process_group_so: ScriptObject, + device_ids: list[int], + async_op: bool = True, + timeout: int = -1, +) -> ScriptObject: + # "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalTensor + + # Barrier is a synchronization primitive - in local simulation, + # we don't need to do any actual work since all "ranks" are in the same process + # Just validate that the tensor is a LocalTensor + assert isinstance(tensor, LocalTensor) + + # In a real distributed setting, barrier would synchronize all processes + # In local simulation, this is essentially a no-op since all ranks are local + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_monitored_barrier_( + tensor: torch.Tensor, + process_group_so: ScriptObject, + device_ids: list[int], + timeout: int, + wait_all_ranks: bool, +) -> None: + # "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int[] device_ids, int timeout, bool wait_all_ranks) -> ()"; + + from . import LocalTensor + + # Monitored barrier is a synchronization primitive with monitoring - in local simulation, + # we don't need to do any actual work since all "ranks" are in the same process + # Just validate that the tensor is a LocalTensor + assert isinstance(tensor, LocalTensor) + + # In a real distributed setting, monitored barrier would synchronize all processes + # and provide monitoring capabilities. In local simulation, this is essentially a no-op + # since all ranks are local and no actual synchronization is needed + return + + +def _local_send( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + dst: int, + tag: int, +) -> ScriptObject: + # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int dst, int tag) -> __torch__.torch.classes.c10d.Work"; + + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + src = int(tensor.__src_rank__) + + LocalRunnerMode.current()._signal_send(src, dst, tensor._local_tensors[src]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so + + +def _local_recv_( + tensors: list[torch.Tensor], + process_group_so: ScriptObject, + src: int, + tag: int, +) -> ScriptObject: + # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int src, int tag) -> __torch__.torch.classes.c10d.Work"; + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + dst = int(tensor.__src_rank__) + + def _recv_and_store(timeout: timedelta) -> bool: + def _wait_and_store(obj: object) -> None: + assert isinstance(obj, torch.Tensor), "Expected to receive a Tensor" + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + tensor._local_tensors[dst] = obj + + LocalRunnerMode.current()._wait_recv(src, dst, _wait_and_store) + return True + + work = PythonCallbackWork(_recv_and_store) + work_so = Work.boxed(work) + return work_so + + +def _local_recv_any_source_( + tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int +) -> ScriptObject: + # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " + # "int tag) -> __torch__.torch.classes.c10d.Work"; + + return _local_recv_(tensors, process_group_so, -1, tag) + + +def _attach_rank(tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Attaches rank as an attribute to given tensor so that the send or recv implementation + knows which rank initiates the operation (note under local tensor mode ). + """ + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + + tensor.__src_rank__ = rank # type: ignore[attr-defined] + return tensor + + +def local_p2p_op( + dst: torch.SymInt, + tensor: torch.Tensor, + op: Callable[[torch.Tensor, int], Work | None], +) -> Work | None | list[Work | None]: + """ + Runs a point-to-point (P2P) operation for all combinations of source and destination ranks. + """ + _check_op(op) + + from . import LocalIntNode + + assert isinstance(dst.node, LocalIntNode), ( + "Expected 'dst' to be a LocalIntNode where the value is the destination rank and key is the source rank" + ) + + w = [] + for s, d in dst.node._local_ints.items(): + tensor = _attach_rank(tensor, s) + w.append(op(tensor, d)) + return w + + +def wait_all(work: Work | None | list[Work | None]) -> None: + """ + Waits for all work objects in the input to complete. + + A single Work object, None, or a list of Work objects (possibly containing None). + If None, does nothing. If a single Work, waits for it to complete. If a list, waits + for each non-None Work in the list to complete. + """ + + if work is None: + return + if isinstance(work, Work): + work = [work] + for w in work: + if w is None: + continue + w.wait() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e13bcc86e5095a0762417cf0c6cfdaa20951ee5d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/__init__.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from .int_tuple import ( + as_tuple, + crd2crd, + crd2idx, + elem_scale, + flatten, + has_none, + idx2crd, + inner_product, + IntTuple, + is_int, + is_tuple, + match_structure, + product, + shape_div, + signum, + slice_, + suffix_product, + tuple_max, +) +from .layout import ( + coalesce, + complement, + composition, + cosize, + filter, + is_layout, + Layout, + LayoutBase, + left_inverse, + logical_divide, + logical_product, + make_layout, + right_inverse, + size, + slice_and_offset, + tiled_divide, + tiled_product, + zipped_divide, + zipped_product, +) +from .typing import Integer diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3406a7399b1af1f892c17e2cf34755aeed244c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/int_tuple.py @@ -0,0 +1,269 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Functions for manipulating IntTuples +""" + +from functools import reduce +from itertools import chain +from typing import TypeAlias +from typing_extensions import TypeIs + +from .typing import Integer + + +# Type aliases for better readability +IntTuple: TypeAlias = int | tuple["IntTuple", ...] + + +def is_int(x: object) -> TypeIs[int]: + return isinstance(x, Integer) + + +def is_tuple(x: object) -> TypeIs[tuple]: + return isinstance(x, tuple) + + +def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]: + if is_int(x): + return (x,) + return x + + +def match_structure(a: IntTuple, b: IntTuple) -> bool: + if is_int(a) and is_int(b): + return True + if is_tuple(a) and is_tuple(b): + return len(a) == len(b) and all(match_structure(x, y) for x, y in zip(a, b)) + return False + + +def flatten(t: IntTuple) -> tuple[int, ...]: + if is_tuple(t): + if len(t) == 0: + return () + else: + return tuple(i for a in t for i in flatten(a)) + else: + return (t,) + + +def signum(a: int) -> int: + return bool(a > 0) - bool(a < 0) + + +def product(a: IntTuple) -> int: + if is_tuple(a): + return reduce(lambda val, elem: val * product(elem), a, 1) + else: + return a + + +def inner_product(a: IntTuple, b: IntTuple) -> int: + if is_tuple(a) and is_tuple(b): # tuple tuple + assert len(a) == len(b) + return sum(inner_product(x, y) for x, y in zip(a, b)) + else: # "int" "int" + assert not is_tuple(a) and not is_tuple(b) + return a * b + + +def tuple_max(a: IntTuple) -> int: + if is_tuple(a): + return max(tuple_max(x) for x in a) + else: + return a + + +def elem_scale(a: IntTuple, b: IntTuple) -> IntTuple: + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(elem_scale(x, y) for x, y in zip(a, b)) + else: # tuple "int" + raise AssertionError("Invalid combination: tuple with int") + else: + if is_tuple(b): # "int" tuple + return elem_scale(a, product(b)) + else: # "int" "int" + return a * b + + +# Inclusive prefix ceil div with output congruent to input a +def shape_div(a: IntTuple, b: IntTuple) -> IntTuple: + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(shape_div(x, y) for x, y in zip(a, b)) + else: # tuple "int" + # r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] + r = [] + for v in a: + r.append(shape_div(v, b)) + b = shape_div(b, product(v)) + return tuple(r) + else: + if is_tuple(b): # "int" tuple + return shape_div(a, product(b)) + else: # "int" "int" + assert a % b == 0 or b % a == 0 + return (a + b - 1) // b + + +# Exclusive suffix product with output congruent to input a (lexicographic) +def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple: + # TODO: With all these length asserts, may want to create a zip_strict wrapper. + if is_tuple(a): + if is_tuple(init): # tuple tuple + assert len(a) == len(init) + return tuple(suffix_product(x, i) for x, i in zip(a, init)) + else: # tuple "int" + # Process from right to left for lexicographic ordering + # r = [prefix_product(a[len(a)-1],init)] + + # [prefix_product(a[i],init := init * product(a[i+1])) for i in range(len(a)-1,0)].reverse() + r = [] + + # Calculate products from right to left, appending to list + for i in range(len(a) - 1, -1, -1): + r.append(suffix_product(a[i], init)) + init = init * product(a[i]) + + # Reverse to get correct lexicographic order + r.reverse() + return tuple(r) + else: + if is_tuple(init): # "int" tuple + raise AssertionError("Invalid combination: int with tuple init") + else: # "int" "int" + return init + + +def idx2crd(idx: IntTuple, shape: IntTuple, stride: IntTuple | None = None) -> IntTuple: + if stride is None: + stride = suffix_product(shape) + + if is_tuple(idx): + if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple + assert len(idx) == len(shape) and len(stride) == len(shape) + return tuple(idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride)) + else: # tuple "int" "int" + raise AssertionError("Invalid combination: tuple with int stride") + else: + if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple + assert len(shape) == len(stride) + return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride)) + else: # "int" "int" "int" + assert not is_tuple(shape) and not is_tuple(stride) + return (idx // stride) % shape # all are ints after type checks + + +def crd2idx( + crd: IntTuple | None, shape: IntTuple, stride: IntTuple | None = None +) -> int: + if stride is None: + stride = suffix_product(shape) + + if is_tuple(crd): + if is_tuple(shape) and is_tuple(stride): # tuple tuple tuple + assert len(crd) == len(shape) and len(stride) == len(shape) + return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) + else: # tuple "int" "int" + raise AssertionError(f"Invalid combination: crd={crd}, shape={shape}") + else: + if crd is None: + crd = 0 + + if is_tuple(shape) and is_tuple(stride): # "int" tuple tuple + assert len(shape) == len(stride) + result = 0 + # Process from right to left for lexicographic ordering + for i in range(len(shape) - 1, 0, -1): + result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) + crd = crd // product(shape[i]) + if len(shape) > 0: + result += crd2idx(crd, shape[0], stride[0]) + return result + else: # "int" "int" "int" + assert not is_tuple(shape) and not is_tuple(stride) + return crd * stride # all are ints after type checks + + +# Transform crd into the dst_shape's iteration space +def crd2crd( + crd: IntTuple, dst_shape: IntTuple, src_shape: IntTuple | None = None +) -> IntTuple: + if is_tuple(crd): + if is_tuple(dst_shape): # tuple tuple + assert len(crd) == len(dst_shape) + return tuple(crd2crd(x, y) for x, y in zip(crd, dst_shape)) + else: # tuple "int" + # Ambiguous unless we have src_shape + assert src_shape is not None + return crd2idx(crd, src_shape) + else: + if is_tuple(dst_shape): # "int" tuple + return idx2crd(crd, dst_shape) + else: # "int" "int" + assert crd < dst_shape + return crd + + +# Filter trg according to crd: keep only elements of trg that are paired with None +def slice_(crd: None | tuple | int, trg: tuple | int) -> tuple | int: + if is_tuple(crd): + if is_tuple(trg): # tuple tuple + assert len(crd) == len(trg) + # match C++ behavior of `filter_tuple` using `tuple_cat(...)` + return tuple( + chain( + *filter( # type: ignore[arg-type] # filter returns Iterator which is compatible + lambda x: x != (), + [slice_(c, s) for c, s in zip(crd, trg)], + ) + ) + ) + else: + raise AssertionError("Invalid combination: tuple crd with int trg") + elif crd is None: + # match C++ behavior `return cute::tuple{b};` + return (trg,) + else: + return () + + +# Determine if None appears at any of an int_tuples' terminals +def has_none(a: None | tuple | int) -> bool: + if is_tuple(a): + return any(has_none(v) for v in a) + else: + return a is None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..0adf94b5b142b925f7ab35dc82f46c4bf509001a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/layout.py @@ -0,0 +1,470 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Definition of CuTe Layouts and functions to manipulate them which works with the order +of lexicographic instead of co-lexicographic as implemented in the original layout.py +""" + +from itertools import chain +from typing import TypeAlias +from typing_extensions import Self, TypeIs + +from .int_tuple import ( + crd2idx, + flatten, + has_none, + IntTuple, + is_int, + is_tuple, + product, + slice_, + suffix_product, +) + + +# Type aliases +CoordinateType: TypeAlias = ( + int | IntTuple | tuple[object, ...] | None +) # Input for slice_ and crd2idx functions + + +class LayoutBase: + pass + + +def is_layout(x: object) -> TypeIs["Layout"]: + return isinstance(x, LayoutBase) + + +class Layout(LayoutBase): + def __init__(self, _shape: IntTuple, _stride: IntTuple | None = None) -> None: + self.shape = _shape + if _stride is None: + self.stride = suffix_product(self.shape) + else: + self.stride = _stride + + # operator == + def __eq__(self, other: object) -> bool: + if not isinstance(other, Layout): + return False + return self.shape == other.shape and self.stride == other.stride + + # operator len(L) (len [rank] like tuples) + def __len__(self) -> int: + if is_tuple(self.shape): + return len(self.shape) + else: + return 1 + + # operator () (map coord to idx) + def __call__(self, *args: CoordinateType) -> Self | int: + """ + Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + OR + Slice the layout and return the sublayout (Coord has an Underscore slice op) + + Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ + """ + if has_none(args): + if len(args) == 1: + return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) + else: + return Layout(slice_(args, self.shape), slice_(args, self.stride)) + else: + if len(args) == 1: + return crd2idx(args[0], self.shape, self.stride) # type: ignore[arg-type] + else: + return crd2idx(args, self.shape, self.stride) # type: ignore[arg-type] + + # operator [] (get-i like tuples) + def __getitem__(self, i: int) -> Self: + if is_tuple(self.shape): + return Layout(self.shape[i], self.stride[i]) # type: ignore[index] + else: + assert i == 0 + return Layout(self.shape, self.stride) + + # size(layout) Size of the domain + def size(self) -> int: + return product(self.shape) + + # cosize(layout) Size of the codomain + def cosize(self) -> int: + return self(self.size() - 1) + 1 # type: ignore[operator] + + # print and str + def __str__(self) -> str: + return f"{self.shape}:{self.stride}" + + # error msgs and representation + def __repr__(self) -> str: + return f"Layout({self.shape},{self.stride})" + + +# Type aliases +LayoutOrIntTuple: TypeAlias = Layout | IntTuple +LayoutProfile: TypeAlias = tuple[object, ...] | Layout | None +LayoutInput: TypeAlias = Layout | IntTuple | tuple[object, ...] | None + + +# Make Layout from a list of layouts (each layout it's own mode in the result) +def make_layout(*layouts: Layout | tuple[Layout, ...]) -> Layout: + if len(layouts) == 1 and not is_layout(layouts[0]): + layouts = layouts[0] + + shape, stride = zip(*((a.shape, a.stride) for a in layouts)) # type: ignore[union-attr] + return Layout(shape, stride) + + +# Size of the domain +def size(layout: LayoutOrIntTuple) -> int: + if is_layout(layout): + return layout.size() + return product(layout) + + +# Size of the codomain +def cosize(layout: Layout) -> int: + return layout.cosize() + + +# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function +def coalesce(layout: Layout, profile: LayoutProfile = None) -> Layout: + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout( + chain( + (coalesce(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type] + (layout[i] for i in range(len(profile), len(layout))), + ) + ) + + result_shape = [1] + result_stride = [0] + # Since we now follow lexicographic order, we need to process from right to left. + # And to make implementation more efficient, we append to the end of list and reverse it in the end. + for shape, stride in zip( + reversed(flatten(layout.shape)), reversed(flatten(layout.stride)) + ): + # skip their shape-1s + if shape == 1: + continue + # replace our shape-1 with anything + elif result_shape[-1] == 1: + result_shape[-1] = shape + result_stride[-1] = stride + # merge modes if the shape*stride match + elif result_shape[-1] * result_stride[-1] == stride: + result_shape[-1] = result_shape[-1] * shape + # append a new mode + else: + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + result_shape.reverse() + result_stride.reverse() + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them +def filter(layout: Layout, profile: LayoutProfile = None) -> Layout: + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout( + chain( + (filter(layout[i], profile[i]) for i in range(len(profile))), # type: ignore[arg-type] + (layout[i] for i in range(len(profile), len(layout))), + ) + ) + + result_shape = [] + result_stride = [] + for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)): + # skip their shape-1s and stride-0s + if not (shape == 1 or stride == 0): + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 0: + return Layout(1, 0) + else: + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout composition +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def composition(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return composition(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + (composition(layoutA[i], layoutB[i]) for i in range(len(layoutB))), # type: ignore[arg-type] + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + elif is_tuple(layoutB.shape): + return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) # type: ignore[arg-type, attr-defined] + + if layoutB.stride == 0: + return Layout(layoutB.shape, 0) + else: + result_shape = [] + result_stride = [] + rest_shape = layoutB.shape + rest_stride = layoutB.stride + flat_A = coalesce(layoutA) + # when left layout is multi-dimensional sublayout, aka, self = (a,b,...,c):(x,y,...,z), layout = s:d, + # for integral s and d means that we want: + # (1) “remove” the first d elements from left, starting from rightmost. (This will increase the stride.) + # (2) “keep” the first s of those strided elements. (This does not affect the stride.) + # For example, if self = (6,2):(2,1), layout = (3:2) + # Step 1: remove the first 2 elements from self with stride increase, i.e., (6,2):(2,1) -> (6,1):(2,2) + # Step 2: keep the first 3 of those strided elements, i.e., (6,1):(2,2) -> (3,1):(2,2) + # Because we are going lexicographically, we go through left layout from right to left. + for curr_shape, curr_stride in zip( + reversed(flatten(flat_A.shape)[1:]), reversed(flatten(flat_A.stride)[1:]) + ): + assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 # type: ignore[operator] + new_shape = min(max(1, curr_shape // rest_stride), rest_shape) # type: ignore[operator] + + if new_shape != 1: + result_shape.append(new_shape) # Append to end, will reverse later + result_stride.append(rest_stride * curr_stride) + + rest_shape = rest_shape // new_shape # type: ignore[operator] + rest_stride = -( + -rest_stride // curr_shape # type: ignore[operator] + ) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) + + # When left has single-size sublayout or reach the last sublayout, aka, left = a:b, layout = s:d, + # the result is rather trivial: left o layout = a:b o s:d = s:(b*d). + # For example, if self = (6:2), layout = (3:2), the result is (3:(2*2)) = (3:4). + if rest_shape != 1 or len(result_shape) == 0: + result_shape.append(rest_shape) # Append to end, will reverse later + result_stride.append(rest_stride * flatten(flat_A.stride)[0]) + + # Reverse the lists because we build lists in reverse order (append to end), this way it is more efficient. + result_shape.reverse() + result_stride.reverse() + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) # type: ignore[arg-type] + else: + return Layout(tuple(result_shape), tuple(result_stride)) # type: ignore[arg-type] + + +# Layout complement +def complement(layout: LayoutOrIntTuple, max_idx: int = 1) -> Layout: + if is_int(layout): + return complement(Layout(layout)) + + result_shape = [] + result_stride = [] + current_idx = 1 + + sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) # type: ignore[union-attr] + for stride, shape in sorted_DS: + if stride == 0 or shape == 1: + continue + + in_bound = current_idx <= shape * stride + # To support symbolic value which can't be evaluated now + assert (type(in_bound) is not bool) or in_bound + + result_shape.append(stride // current_idx) + result_stride.append(current_idx) + current_idx = shape * stride + + result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div + result_stride.append(current_idx) + # This is different from original pycute implementation, because we want to follow the lexicographic order here + # where the right-most dimension is the innermost dimension (smallest stride). + result_shape.reverse() + result_stride.reverse() + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout right inverse +def right_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + + result_shape = [] + result_stride = [] + current_idx = 1 + + flat_shape = flatten(layout.shape) # type: ignore[union-attr] + flat_stride = flatten(layout.stride) # type: ignore[union-attr] + sorted_DSA = sorted(zip(flat_stride, flat_shape, suffix_product(flat_shape))) # type: ignore[arg-type] + for stride, shape, rstride in sorted_DSA: + if shape == 1: + continue + if current_idx != stride: + break + + result_shape.append(shape) + result_stride.append(rstride) + current_idx = shape * stride + + result_shape.reverse() + result_stride.reverse() + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout left inverse +def left_inverse(layout: LayoutOrIntTuple | None) -> Layout | None: + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + return right_inverse(make_layout(complement(layout), layout)) # type: ignore[arg-type] + + +# Split a layout by the composition of B and the "rest" +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + ( + logical_divide(layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + + return composition( + layoutA, + make_layout(layoutB, complement(layoutB, size(layoutA))), + ) + + +# Reproduce a layoutA over a layoutB +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout( + chain( + ( + logical_product(layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ) + + return make_layout( + layoutA, + composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB), + ) + + +# Gather the modes from a hierarchical logical_divide or logical_product +def hier_unzip( + splitter: object, + layoutA: Layout, + layoutB: LayoutInput, +) -> Layout: + if layoutB is None: + return make_layout(Layout(1, 0), layoutA) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + # A layout with shape ((A,a),(B,b),(C,c)) + split = make_layout( + hier_unzip(splitter, layoutA[i], layoutB[i]) # type: ignore[arg-type] + for i in range(len(layoutB)) + ) + # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) + return make_layout( + make_layout(split[i][0] for i in range(len(layoutB))), # type: ignore[arg-type] + make_layout( + chain( # type: ignore[arg-type] + (split[i][1] for i in range(len(layoutB))), + (layoutA[i] for i in range(len(layoutB), len(layoutA))), + ) + ), + ) + + # splitter must return a rank-2 layout + return splitter(layoutA, layoutB) # type: ignore[operator] + + +# Apply logical divide hierarchically and gather the split modes into two modes +def zipped_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + return hier_unzip(logical_divide, layoutA, layoutB) + + +# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode +def tiled_divide(layoutA: Layout, layoutB: LayoutInput) -> Layout: + result = zipped_divide(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type] + + +# Apply logical product hierarchically and gather the split modes into two modes +def zipped_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + return hier_unzip(logical_product, layoutA, layoutB) + + +# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode +def tiled_product(layoutA: Layout, layoutB: LayoutInput) -> Layout: + result = zipped_product(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) # type: ignore[arg-type] + + +def slice_and_offset(crd: tuple[object, ...], layout: Layout) -> tuple[Layout, int]: + return ( + Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), + crd2idx(crd, layout.shape, layout.stride), # type: ignore[arg-type] + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6fe0a9c66e800186b4d84ef52cc12d6baeb1f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_pycute/typing.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from abc import ABC + + +class Integer(ABC): # noqa: B024 # Uses __subclasshook__ instead of abstract methods + @classmethod + def __subclasshook__(cls, c: type) -> bool: + if c in [bool, float]: + return False + + return issubclass(c, int) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85a313c779e7aa87726f425146048fcd37efd261 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__init__.py @@ -0,0 +1 @@ +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55dbbed944fc75f40e943ce131e16fb00036c952 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80f8deac9dbf6021f68a52ef028758e9d621ff0e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..318ac42f7aaf75431a01c2fab3b5d44b26ee8006 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b70f56f6286a5765a90ed9d66286b1601521fb27 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b9e8144ed4e1b1b3a04db8ecb800bdabb1aabb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..605972f9201a691b4d1bfcf3198b823e2ad7b1d8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbb9a82fe0acfaacdc7c0f44fa70c33277d1c99d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd641b3f9443faa64b6b54c3ab209f8167a56f7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/_utils.py @@ -0,0 +1,32 @@ +from collections.abc import Sequence + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." + + +def narrow_tensor_by_index( + tensor: torch.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> torch.Tensor: + """ + Narrow the tensor according to ``offsets`` and ``sizes``. + """ + narrowed_tensor = tensor + for idx, (offset, size) in enumerate(zip(offsets, sizes)): + if size < tensor.size(idx): + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) + return narrowed_tensor + + +def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: + """ + Narrow the tensor according to the metadata + """ + return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/api.py new file mode 100644 index 0000000000000000000000000000000000000000..82589119d7afa6086b6b6289954d88676516b620 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/api.py @@ -0,0 +1,305 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + + +def _shard_tensor( + tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None +) -> ShardedTensor: + """ + Given a :class:`torch.Tensor`, it shards that tensor according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + if not tensor.is_contiguous(): + raise ValueError("input tensor is not a contiguous Tensor") + + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) + world_size = dist.get_world_size(pg) + current_rank = dist.get_rank(pg) + + # Validate src_rank and sharding_spec are same across all ranks. + gathered_list = [None] * world_size + dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) + + for idx, entry in enumerate(gathered_list): + if src_rank != entry[0]: # type: ignore[index] + raise ValueError( + f"src_rank={src_rank} on rank: {current_rank} does not " # type: ignore[index] + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) + if sharding_spec != entry[1]: # type: ignore[index] + raise ValueError( + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " # type: ignore[index] + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) + + st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) + + return st + + +def shard_parameter( + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): + """ + Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that + module, it shards that parameter according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + This method replaces ``module.param_name`` with a + :class:`torch.distributed._sharded_tensor.ShardedTensor` + + Args: + module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded. + param_name (str): Name of the parameter of ``module`` that needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + # Perform some validation first. + if not hasattr(module, param_name): + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) + + if not tensor.is_contiguous(): + raise ValueError(f"param: {param_name} is not a contiguous Tensor") + + st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) + + # Replace param with ShardedTensor. + module.register_parameter(param_name, nn.Parameter(st)) + + +# Tracks the current process group in the load context manager. +_CURRENT_PROCESS_GROUP: dist.ProcessGroup | None = None + + +@contextmanager +def load_with_process_group(process_group): + """ + Context manager to set the process group with which to load a ShardedTensor. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is not None: + raise RuntimeError( + 'ProcessGroup already set by previous "load_with_process_group" ' + "context manager" + ) + _CURRENT_PROCESS_GROUP = process_group + try: + yield process_group + finally: + _CURRENT_PROCESS_GROUP = None + + +def _get_current_process_group(): + """ + Retrieves the current process group set by ``load_with_process_group``. + If not set, it just returns the default group. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is None: + return distributed_c10d._get_default_group() + else: + return _CURRENT_PROCESS_GROUP + + +def _reshard_output( + module: torch.nn.Module, resharding_spec: ShardingSpec +) -> torch.nn.Module: + """ + Hook a module with output resharding in the forward pass according + to the given ``resharding_spec``. + + Args: + module (:class:`torch.nn.Module`): Module whose output needs to be resharded. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how the output of the module will be resharded. + + Returns: + A :class:`torch.nn.Module` object with reshard API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + return output.reshard(resharding_spec) + return output + + module.register_forward_hook(hook_func) + return module + + +def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module: + """ + Hook a module with local shards collection in the forward pass. + + This API is typically used to convert a sharded representation back to data parallel + representation. In particular, it returns the local tensor for this Shard. If the + size along the sharding dimension for the local tensor is 1, this dimension is removed + from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically + a local Tensor of size [16] across each rank and not [1, 16] across each rank. + + Args: + module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the + local tensor value needs to be returned. + + Returns: + A :class:`torch.nn.Module` object with collection API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + local_tensor = output.local_tensor() + # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec + sharding_spec = output._sharding_spec + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): + local_tensor = local_tensor.squeeze( + output._sharding_spec.dim # type: ignore[attr-defined] + ) + return local_tensor + + module.register_forward_hook(hook_func) + return module + + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): + """ + Shards a given module according to the provided sharding `plan`. This method + first shards all the parameters according to the given sharding `plan`. Then if + `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it + will tag the output of modules according `output_plan`, convert the module's + output back to data parallel according to `return_local_tensor`. + + Needs to be called on all ranks in an SPMD fashion. + + Args: + module (:class:`torch.nn.Module`): The module to apply sharding to + plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`): + The ShardingPlan which specified param name to ShardingSpec to apply to + each parameter. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the module that would be sharded and scattered across the rest + of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + """ + # record Sharder paths for sanity check on the plan to ensure items in the plan + # does not conflict with the submodule tree that the Sharder is working with + sharder_paths = [] + for name, spec in plan.plan.items(): + if isinstance(spec, Sharder): + sharder_paths.append(name) + + # shard the parameter according to the ShardingPlan + for name, spec in plan.plan.items(): + if isinstance(spec, ShardingSpec): + # if found a sharding spec, try to shard the parameter + module_path, _, param_name = name.rpartition(".") + + for sharder_path in sharder_paths: + if module_path.startswith(sharder_path): + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) + + mod = module.get_submodule(module_path) + shard_parameter( + mod, param_name, spec, src_rank=src_rank, process_group=process_group + ) + elif isinstance(spec, Sharder): + parent_mod_path, _, _mod_name = name.rpartition(".") + if name == "": + raise KeyError("Module path must not be empty for custom sharder!") + mod = module.get_submodule(name) + parent_mod = module.get_submodule(parent_mod_path) + sharded_mod = spec.shard(mod) + # swap this submodule with the sharded module + parent_mod.mod_name = sharded_mod + else: + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) + + # reshard output if there's an entry in `reshard_output` for this module + if plan.output_plan is not None: + for module_path, output_spec in plan.output_plan.items(): + if isinstance(output_spec, ShardingSpec): + mod = module.get_submodule(module_path) + _reshard_output(mod, output_spec) + else: + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) + # convert the output back to data parallel for the modules appears in + # `return_local_tensor` of the plan, we will call `_collect_local_shard` + # to collect the local tensor for output of modules + if plan.return_local_tensor is not None: + for module_path in plan.return_local_tensor: + mod = module.get_submodule(module_path) + _collect_local_shard(mod) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85915636a014640d8fff5a29db602c4a114f1b1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/checkpoint/__init__.py @@ -0,0 +1,19 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed.checkpoint` package. +import sys +import warnings + +import torch +from torch.distributed.checkpoint import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._shard.checkpoint` will be deprecated, " + "use `torch.distributed.checkpoint` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c98b8c87ca2c7ceb1608a59673738a7e57333035 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/common_op_utils.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs + +import torch +from torch.utils import _pytree as pytree + + +def _basic_validation(op, args=(), kwargs=None): + """ + Common validation across all ops go in here. + """ + from torch.distributed._shard.sharded_tensor import ShardedTensor + + if len(args) == 0 and (kwargs is None or len(kwargs) == 0): + raise ValueError(f" No input for '{op.__name__}'!") + + # Validate types + has_distributed_tensor = False + + def is_distributed_tensor(e): + nonlocal has_distributed_tensor + if isinstance(e, ShardedTensor): + has_distributed_tensor = True + + pytree.tree_map_(is_distributed_tensor, args) + pytree.tree_map_(is_distributed_tensor, kwargs) + + if not has_distributed_tensor: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} are called without any distributed tensor!" + ) + + # Validate all distributed tensors use the same PG. + cur_pg: torch.distributed.ProcessGroup | None = None + + def validate_pg(e): + nonlocal cur_pg + if isinstance(e, ShardedTensor): + if cur_pg is not None and e._process_group is not cur_pg: + raise RuntimeError( + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." + ) + cur_pg = e._process_group + + pytree.tree_map_(validate_pg, args) + pytree.tree_map_(validate_pg, kwargs) + + +def _register_default_op(op, decorator): + @decorator(op) + def tensor_default_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for the default tensor ops that + behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or + ``torch.Tensor.dtype``. We simply lower to the real op call with + DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` + to avoid recursions. + """ + if kwargs is None: + kwargs = {} + + with torch._C.DisableTorchFunctionSubclass(): + return op(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..63ef073b1c494ab450bca79c83f3867548140fd8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/metadata.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import reduce + +from torch.distributed.remote_device import _remote_device + + +@dataclass +class ShardMetadata: + """ + Represents a shard of the overall Tensor including its + offsets, lengths and device placement. + + Args: + shard_offsets(List[int]): Offsets in the original tensor indicating + the start offsets for this shard. Should have the same rank as + the original tensor. + shard_sizes(List[int]): Integers indicating the size of each + dimension for this shard. Should have the same rank as the + original tensor. + placement(:class:`torch.distributed._remote_device`): + Specifies the placement of this shard. + """ + + __slots__ = ["shard_offsets", "shard_sizes", "placement"] + + shard_offsets: list[int] + shard_sizes: list[int] + placement: _remote_device | None + + def __init__( + self, + shard_offsets: list[int], + shard_sizes: list[int], + placement: str | _remote_device | None = None, + ): + self.shard_offsets = shard_offsets + self.shard_sizes = shard_sizes + if isinstance(placement, str): + self.placement = _remote_device(placement) + else: + self.placement = placement + if len(self.shard_offsets) != len(self.shard_sizes): + raise ValueError( + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) + + for i in range(len(self.shard_offsets)): + if self.shard_offsets[i] < 0: + raise ValueError("shard_offsets should be >=0") + if self.shard_sizes[i] < 0: + raise ValueError("shard_sizes should be >= 0") + + def __hash__(self): + def _hash_reduce(a, b): + return (a << 8) + hash(b) + + res = reduce(_hash_reduce, self.shard_offsets, 37) + res = reduce(_hash_reduce, self.shard_sizes, res) + res = _hash_reduce(res, self.placement) + return res diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..12e0b1895e2f053e6c4a975cb6d3c0118baf50bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/op_registry_utils.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import functools +from inspect import signature + +from .common_op_utils import _basic_validation + + +""" +Common utilities to register ops on ShardedTensor +and PartialTensor. +""" + + +def _register_op(op, func, op_table): + """ + Performs basic validation and registers the provided op in the given + op_table. + """ + if len(signature(func).parameters) != 4: + raise TypeError( + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) + + op_table[op] = func + + +def _decorator_func(wrapped_func, op, op_table): + """ + Decorator function to register the given ``op`` in the provided + ``op_table`` + """ + + @functools.wraps(wrapped_func) + def wrapper(types, args, kwargs, process_group): + _basic_validation(op, args, kwargs) + return wrapped_func(types, args, kwargs, process_group) + + _register_op(op, wrapper, op_table) + return wrapper diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..effae2e3cd1b89027cf06bf6e603e6ca84551520 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/__init__.py @@ -0,0 +1,53 @@ +from collections.abc import Iterator +from typing import Union + +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer + + +def named_params_with_sharded_tensor( + module: nn.Module, + prefix: str = "", + recurse: bool = True, +) -> Iterator[tuple[str, nn.Parameter | ShardedTensor]]: + r"""Returns an iterator over module parameters (together with the + ShardedTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:torch.distributed._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (str, Union[Tensor, ShardedTensor]): Tuple containing + the name and parameter (or ShardedTensor parameter) + + Example:: + + >>> # xdoctest: +SKIP + >>> model = torch.nn.Linear(*linear_size) + >>> shard_parameter(model, "weight", spec) + >>> for name, param in named_params_with_sharded_tensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ShardedTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ("." if mod_prefix else "") + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2989e85496090782fcdb39d0e6613b82155ea23c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_optim/api.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from collections.abc import Mapping +from typing import Any + +import torch.optim as optim +from torch import Tensor +from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class ShardedOptimizer(optim.Optimizer): + def __init__( + self, + named_params: Mapping[str, Tensor | ShardedTensor], + optimizer_class, + *optimizer_args, + **optimizer_kwargs, + ): + """ + ShardedOptimizer collects all tensors and local shard tensors of + ShardedTensor, then use these tensors as ``params`` for optimizers + + Args: + named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict + of parameters, where key is the parameter key, value is either + Tensor or ShardedTensor parameter. + optimizer_class (torch.optim.Optimizer): the Optimizer to use + locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc. + *optimizer_args: the arguments to initialize the optimizer. + **optimizer_kwargs: the key-word arguments to initialize the optimizer. + + """ + tensors: list[Tensor] = [] + for value in named_params.values(): + if isinstance(value, ShardedTensor): + tensors.extend( + local_shard.tensor for local_shard in value.local_shards() + ) + else: + tensors.append(value) + + self.named_params = named_params + self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self.param_groups = self._optim.param_groups + self.state = self._optim.state + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + r"""Resets the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + self._optim.zero_grad(set_to_none) + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + self._optim.step(closure) + + def state_dict(self) -> dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices like torch.optim.Optimizer. + This allows for advanced functionality like optimizer re-sharding to be implemented. + """ + # TODO: implement state_dict + raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") + + def load_state_dict(self, state_dict: Mapping[str, Any]): + r"""Loads the ShardedOptimizer state. + + Args: + state_dict (dict): ShardedOptimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # TODO: implement load_state_dict + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) + + def add_param_group(self, param_group: Any): + r"""Add a new param group""" + # TODO: implement add_param_group + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3af3ed3595378ca8522384f295ef6ea9930ebf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import functools +from typing import TYPE_CHECKING + +import torch +from torch.distributed._shard.op_registry_utils import _decorator_func + +from .api import ( + _CUSTOM_SHARDED_OPS, + _SHARDED_OPS, + Shard, + ShardedTensor, + ShardedTensorBase, + ShardedTensorMetadata, + TensorProperties, +) +from .metadata import ShardMetadata # noqa: F401 + + +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=1, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=0, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) - the value to fill the output tensor with. + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] + return sharded_tensor + + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)`. The shape of the tensor is defined by the + variable argument `size`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + with mean `0` and variance `1` (also called standard normal distribution). The shape + of the tensor is defined by the variable argument `size`. Needs to be called on all ranks + in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def init_from_local_shards( + local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: + """ + Creates an :class:`ShardedTensor` from local shards and the global metadata. + Needs to be called on all ranks in an SPMD fashion. + + Args: + local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list + of shards that represent the local shards on this rank. + global_size (int...): a list, tuple, or `torch.Size` of integers defining the + shape of the overall sharded tensor. + + Keyword args: + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object handle on this rank + + + Examples: + Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), + each shard have a (5, 5) local tensor, we can do it like below: + + on rank 0: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[0, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:0/cuda:0" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + + on rank 1: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[5, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:1/cuda:1" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + """ + return ShardedTensor._init_from_local_shards( + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs + ) + + +def state_dict_hook(module, destination, prefix, local_metadata): + """ + Hook to add ShardedTensor to Module's ``state_dict``. Needs to be + registered to the Module using + :meth:`torch.nn.Module._register_state_dict_hook`. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + if isinstance(attr, ShardedTensor): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + destination[key] = attr + + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """ + Pre-load state dict hook to add ShardedTensor to the module. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name in submodule.__dict__: + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + setattr(submodule, attr_name, state_dict[key]) + + +def custom_sharded_op_impl(func): + """ + Provides a way for users to write their own custom sharded operator. This + can be used to override existing ShardedTensor operators or write a new + one not supported by ShardedTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ShardedTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example:: + >>> # xdoctest: +SKIP + >>> @custom_sharded_op_impl(torch.nn.functional.linear) + >>> def my_custom_sharded_linear(types, args, kwargs, process_group): + >>> ... + >>> # xdoctest: +SKIP("Undefined variables") + >>> input = torch.rand(10, 32) + >>> weight = sharded_tensor.rand(32, 16) + >>> bias = torch.rand(16) + >>> # This will call 'my_custom_sharded_linear' + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + There is an additional ``process_group`` parameter which is the + process_group used for the ShardedTensor and can be used by + implementations for communications within a sharded implementation. + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + + +def _sharded_op_impl(func): + """ + Decorator to register a default sharded op. + """ + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + + +# Import all builtin sharded ops +from ._ops import * # noqa: F403 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be6d01fc8e54ee214fafa847c9261db375d8b87e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -0,0 +1,13 @@ +import torch.distributed._shard.sharded_tensor._ops.misc_ops +import torch.distributed._shard.sharded_tensor._ops.tensor_ops + +# Import all ChunkShardingSpec ops +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8b84c1684c32456989e3998b3d4c30c34cb5dbf4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + + +# This is used by `_apply()` within module.py to set new +# parameters after apply a certain method, we should follow +# the future behavior of overwriting the existing tensor +# instead of doing in-place change using `.data = `. +@_sharded_op_impl(torch._has_compatible_shallow_copy_type) +def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None): + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e8677d6ae7c91cf8d871ff697e057b554b794c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/api.py @@ -0,0 +1,1368 @@ +# mypy: allow-untyped-defs +from __future__ import annotations # type: ignore[attr-defined] + +import copy +import operator +import threading +import warnings +import weakref +from dataclasses import dataclass +from functools import reduce +from typing import cast, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d, rpc +from torch.distributed._shard._utils import DEPRECATE_MSG +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) +from torch.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, +) +from torch.distributed.remote_device import _remote_device +from torch.utils import _pytree as pytree + +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard +from .shard import Shard +from .utils import ( + _flatten_tensor_size, + _parse_and_validate_remote_device, + _validate_output_tensor_for_gather, + build_global_metadata, + build_metadata_from_local_shards, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch.distributed._shard.metadata import ShardMetadata + + +# Tracking for sharded tensor objects. +_sharded_tensor_lock = threading.Lock() +_sharded_tensor_current_id = 0 +_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {} + +# Default sharded ops +_SHARDED_OPS: dict[Callable, Callable] = {} + +# Customized user ops +_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {} + + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int +): + with _sharded_tensor_lock: + if sharded_tensor_id not in _sharded_tensor_map: + raise RuntimeError( + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) + + sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() + if sharded_tensor is None: + raise RuntimeError("ShardedTensor weakref has been deallocated") + else: + sharded_tensor._register_remote_shards(rrefs, rpc_rank) + + +class ShardedTensorBase(torch.Tensor): + _sharding_spec: shard_spec.ShardingSpec + _metadata: ShardedTensorMetadata + _local_shards: list[Shard] + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + # Use __new__ to construct a wrapper tensor, for recording tensor + # properties and logging purposes. + torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") + + # check sharding spec and build sharded tensor metadata + if not isinstance(sharding_spec, shard_spec.ShardingSpec): + raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}") + + sizes = _flatten_tensor_size(size) + dtype = kwargs["dtype"] + layout = kwargs["layout"] + pin_memory = kwargs["pin_memory"] + requires_grad = kwargs["requires_grad"] + + if dtype is None: + dtype = torch.get_default_dtype() + + tensor_properties = TensorProperties( + dtype, layout, requires_grad, pin_memory=pin_memory + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + sizes, tensor_properties=tensor_properties + ) + + r = torch.Tensor._make_wrapper_subclass( + cls, + sizes, + dtype=dtype, + layout=layout, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + # set sharding spec + r._sharding_spec = sharding_spec + # set metadata + r._metadata = sharded_tensor_metadata + # set local shards + r._local_shards = [] + return r + + def metadata(self) -> ShardedTensorMetadata: + """ + Returns a :class:`ShardedTensorMetadata` object corresponding to the + metadata for the entire tensor. + """ + return self._metadata + + def local_shards(self) -> list[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + @classmethod + def _init_from_local_shards_and_global_metadata( + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + sharding_spec=None, + ) -> ShardedTensorBase: + """ + Initialize a ShardedTensorBase with local shards and a global + ShardedTensorMetadata built on each rank. + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor_base = ShardedTensorBase.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor_base._local_shards = local_shards + return sharded_tensor_base + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + raise RuntimeError( + f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " + "but the there is no custom __torch_dispatch__ implementation for it." + ) + + +class ShardedTensor(ShardedTensorBase): + """ + ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded + across multiple devices and multiple processes. + + ShardedTensor is initialized in an SPMD like fashion where each rank + initializes the ShardedTensor. The ShardedTensor object on each rank + then only stores the local shard for the Tensor and provides global + metadata for all the shards. + + ShardedTensor doesn't provide any Tensor like operations but is a wrapper + providing the Tensor representing the local shard and the global metadata. + Using these, users can build their custom distributed._sharded computations + on top of this primitive. The local shards are all initialized using the + create_op specified by tensor_init_params.create_op, e.g., torch.ones, or + torch.empty + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + .. note:: ShardedTensor uses collectives to do various operations, i.e. it + uses all_gather to do cross rank validations. For NCCL-based process + groups, internal tensor representations of objects must be moved to the + GPU device before communication takes place. In this case, the device + used is given by ``torch.cuda.current_device()`` and it is the user's + responsibility to ensure that this is set so that each rank has an + individual GPU, via ``torch.cuda.set_device()`` + + """ + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + self = super().__new__(cls, sharding_spec, *size, **kwargs) + return self + + def __init__( + self, + sharding_spec: shard_spec.ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, + ): + # prepare initialization, initialize fields like + # _process_group, _local_shards, etc. + self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + if layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if memory_format != torch.contiguous_format: + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + self._metadata.tensor_properties.memory_format = memory_format + + current_rank = dist.get_rank() # global rank + + for shard_metadata in self._metadata.shards_metadata: + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) + if rank == current_rank: + local_tensor = _create_tensor_from_params( + shard_metadata.shard_sizes, + local_device=device, + tensor_properties=self._metadata.tensor_properties, + ) + self._local_shards.append(Shard(local_tensor, shard_metadata)) + + # do post initialization (i.e. register sharded_tensor_id, initialize_rpc) + self._post_init() + + def _prepare_init(self, process_group=None, init_rrefs=False): + self._init_rrefs = init_rrefs + self._sharded_tensor_id = None + + self._process_group = self._normalize_pg(process_group) + self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {} + + def _post_init(self): + # Initialize RPC if available. + if self._init_rrefs: + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + # pyrefly: ignore [bad-assignment] + self._sharded_tensor_id = _sharded_tensor_current_id + # pyrefly: ignore [unsupported-operation] + _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self) + _sharded_tensor_current_id += 1 + + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + "RPC Framework needs to be initialized using" + " torch.distributed.rpc.init_rpc if init_rrefs is set to True" + ) + self._init_rpc() + + def __del__(self): + # Clean up the global map. + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + if ( + hasattr(self, "_sharded_tensor_id") + and self._sharded_tensor_id in _sharded_tensor_map + ): + _sharded_tensor_map.pop(self._sharded_tensor_id) # type: ignore[call-overload] + + def _init_rpc(self): + # Validate PG and RPC ranks match. + pg_rank = dist.get_rank() + rpc_rank = rpc.get_worker_info().id + if pg_rank != rpc_rank: + raise ValueError( + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" + ) + + self._remote_shards = {} + + # Gather all the sharded tensor ids. + worker_infos = rpc._get_current_rpc_agent().get_worker_infos() + rank_to_name = {} + name_to_rank = {} + + for worker_info in worker_infos: + rank_to_name[worker_info.id] = worker_info.name + name_to_rank[worker_info.name] = worker_info.id + + all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id) + + # Share the local shards to the entire world. + futs = [] + rpc_rank = rpc.get_worker_info().id + for rank in range(dist.get_world_size()): + # Skip self. + if rank == dist.get_rank(): + continue + + if len(self.local_shards()) != 0: + rrefs: list[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] + fut = rpc.rpc_async( + rank, + _register_remote_shards, + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) + futs.append(fut) + + torch.futures.wait_all(futs) + + # Barrier for all RPCs to finish on all ranks. + rpc.api._all_gather(None) + + def _get_preferred_device(self) -> torch.device: + """ + Return the preferred device to be used when creating tensors for collectives. + This method takes into account the associated process group + """ + backend = dist.get_backend(self._process_group) + if backend == dist.Backend.NCCL: + return torch.device(torch.cuda.current_device()) + elif backend == dist.Backend.GLOO: + return torch.device("cpu") + else: + backend_config = dist.BackendConfig(backend) + for device, backend_str in backend_config.get_device_backend_map().items(): + if backend_str == backend and device != "cpu": + return torch.device( + device, _get_device_module(device).current_device() + ) + return torch.device("cpu") + + def gather( # type: ignore[override] + self, + dst: int = 0, + out: torch.Tensor | None = None, + enforce_dtype: bool = False, + dtype: torch.dtype | None = None, + ) -> None: + """ + Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the + sharded tensor. + + The API needs to be called on all ranks in SPMD fashion. All ranks should have + the same ``dst``. ``out`` should be a tensor of the same size as the overall + size of the sharded tensor on ``dst`` and ``None`` on all other ranks. + + Args: + dst(int): The rank where full tensor is constructed. + Default: 0 + out (:class `torch.Tensor`, optional): The output full tensor. + Must to be provided ONLY on ``dst`` rank. + Default: ``None`` + enforce_dtype (bool): Deprecated, please use dtype instead. Force the + gathered tensors to be the same type as input and output. + dtype (torch.dtype): Force the gathered tensors to be this dtype. + Default: ``None`` + """ + + def shard_size(shard_md): + return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] + + if enforce_dtype: + warnings.warn( + "`enforce_dtype` is deprecated. Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) + + rank = dist.get_rank(self._process_group) + full_size = self.metadata().size + _validate_output_tensor_for_gather(rank, dst, full_size, out) + + local_shards = self.local_shards() + world_size = dist.get_world_size(self._process_group) + rank_sizes = [0 for _ in range(world_size)] + max_rank_size = 0 + shard_placement: dict[ShardMetadata, tuple[int, int]] = {} + # collect sizes + for shard_md in self.metadata().shards_metadata: + shard_rank = cast(_remote_device, shard_md.placement).rank() + assert shard_rank is not None + + shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) + rank_sizes[shard_rank] += shard_size(shard_md) + max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) + + gather_list: list[torch.Tensor] | None + if rank == dst: + assert out is not None + if enforce_dtype: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = out.dtype + # TODO make it as a view of out tensor + gather_list = [ + torch.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] + else: + gather_list = None + + with torch.no_grad(): + if enforce_dtype and len(local_shards) > 0: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = local_shards[0].tensor.dtype + data = torch.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) + + for shard in local_shards: + src = shard.tensor.flatten() + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank), + stacklevel=2, + ) + continue + shard_offset = shard_placement[shard.metadata][1] + data[shard_offset : shard_offset + src.numel()].copy_(src) + + dist.gather( + tensor=data, + gather_list=gather_list, + dst=dst, + group=self._process_group, + ) + if rank != dst: + return + # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst + out = cast(torch.Tensor, out) + assert gather_list is not None + + full_size = self.metadata().size + dims = len(full_size) + for shard_md in self.metadata().shards_metadata: + rank, rank_offset = shard_placement[shard_md] + tensor = gather_list[rank] + tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)] + tensor = tensor.view(shard_md.shard_sizes) + + out_narrow_view = out + for dim in range(dims): + out_narrow_view = out_narrow_view.narrow( + dim, + shard_md.shard_offsets[dim], + shard_md.shard_sizes[dim], + ) + + out_narrow_view.copy_(tensor) + + def cpu( + self, memory_format=torch.preserve_format, process_group=None + ) -> ShardedTensor: + """ + Returns a copy of this object in CPU memory. + + If this ShardedTensor is already on CPU memory, then no copy is + performed and original object is returned. + + .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupGloo), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with CPU. + """ + # TODO: make this a __torch_function__ op once ShardedTensor becomes a + # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + all_on_cpu = True + for meta in self.metadata().shards_metadata: + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] + + # if every shard is already on CPU, return the original object + if all_on_cpu: + return self + + # if not, returns a copy of this object on CPU + list_shards: list[Shard] = [] + # move all local shards to cpu, and change metadata + for shard in self._local_shards: + cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = torch.device("cpu") # type: ignore[union-attr] + list_shards.append(Shard(cpu_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cpu": # type: ignore[union-attr] + meta.placement._device = torch.device("cpu") # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cpu + + def cuda( + self, + device=None, + non_blocking=False, + memory_format=torch.preserve_format, + process_group=None, + ) -> ShardedTensor: + """ + Returns a copy of this object in CUDA memory, if the original ShardedTensor + is on CPU, we will move the local shard to the current GPU device of each + process in a SPMD fashion. + If this ShardedTensor is already on CUDA memory and local shards on each rank are + already on current device, we still returns a new ShardedTensor object with new + metadata, but no underlying data movements are performed. + .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with GPU. + """ + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + + if device is not None: + device = torch.device(device) if isinstance(device, str) else device + assert ( + isinstance(device, torch.device) + and device.index == torch.cuda.current_device() + ), ( + """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + ) + + current_device = torch.device(torch.cuda.current_device()) + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + # move all local shards to current device, and change metadata + # if local shards already on the current device, there's no + # real data movement, only the metadata are copied. + for shard in self._local_shards: + cuda_tensor = shard.tensor.cuda( + device=current_device, + non_blocking=non_blocking, + memory_format=memory_format, + ) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = current_device # type: ignore[union-attr] + + list_shards.append(Shard(cuda_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cuda": # type: ignore[union-attr] + meta.placement._device = current_device # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cuda + + def to(self, *args, **kwargs) -> ShardedTensor: + current_device: torch.device + if self._local_shards: + current_device = self._local_shards[0].tensor.device + elif self._process_group._get_backend_name() == "gloo": + current_device = torch.device("cpu") + else: + current_device = torch.device(torch.cuda.current_device()) + current_dtype = self.dtype + device_to = current_device + dtype_to = current_dtype + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_to = args[0] + elif isinstance(args[0], torch.device): + device_to = args[0] + elif isinstance(args[0], (str, int)): + device_to = torch.device(args[0]) + elif isinstance(args[0], torch.Tensor): + dtype_to = args[0].dtype + device_to = args[0].device + else: + raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}") + elif len(args) == 2: + device_to, dtype_to = args + else: + dtype_to = kwargs.get("dtype", current_dtype) + device_to = kwargs.get("device", current_device) + + device_to = ( + torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) + + if device_to.type == "cuda": + # if device_to set to cuda, set to current device even + # if user specify the device index. + current_idx = torch.cuda.current_device() + if device_to.index != current_idx: + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead.", + stacklevel=2, + ) + device_to = torch.device(current_idx) + + copy_tensor = kwargs.get("copy", False) + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", torch.preserve_format) + process_group = kwargs.get("process_group") + + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): + # already have correct dtype and device, return itself + return self + + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + + for shard in self._local_shards: + new_tensor = shard.tensor.to( # type: ignore[call-overload] + device=device_to, + dtype=dtype_to, + non_blocking=non_blocking, + copy=copy_tensor, + memory_format=memory_format, + ) + metadata = copy.deepcopy(shard.metadata) + if metadata.placement is not None: + metadata.placement._device = device_to + list_shards.append(Shard(new_tensor, metadata)) + + # update metadata + st_meta = copy.deepcopy(self.metadata()) + st_meta.tensor_properties.dtype = dtype_to + for meta in st_meta.shards_metadata: + meta.placement._device = device_to # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_to = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_to + + @classmethod + def _normalize_pg( + cls, process_group: dist.ProcessGroup | None + ) -> dist.ProcessGroup: + if process_group is not None: + return process_group + return distributed_c10d._get_default_group() + + @classmethod + def _init_from_local_shards( + cls, + local_shards: list[Shard], + *global_size, + process_group=None, + init_rrefs=False, + ): + # recalc metadata handles special ST creation cases like each rank only has tensor available + # caller need to provide None on the unknown dimension of the global size + # We will change None into zeros and go through the same amount of checks as before to create ST + # and use all_gather to calculate the offsets and global size for metadata + # It is compatible with the current use case since, conventionally we don't pass None as global size + # Therefore the old path won't trigger the new feature + recalc_metadata = False + for dim in global_size: + if dim is None: + recalc_metadata = True + if recalc_metadata: + global_size = tuple( + 0 if dim_size is None else dim_size for dim_size in global_size + ) + # STEP 1: Validate the Shardmetadatas locally + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + world_size = dist.get_world_size(process_group) + + local_sharded_tensor_metadata: ShardedTensorMetadata | None = None + global_tensor_size = _flatten_tensor_size(global_size) + + if len(local_shards) > 0: + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) + + # STEP 2. Validate metadata across ranks, and build a global sharded tensor + # metadata by gathering local ShardedTensorMetadata + gathered_metadatas: list[ShardedTensorMetadata | None] = [] + if world_size > 1: + gathered_metadatas = [None for _ in range(world_size)] + + dist.all_gather_object( + gathered_metadatas, local_sharded_tensor_metadata, group=process_group + ) + else: + gathered_metadatas = [local_sharded_tensor_metadata] + + global_sharded_tensor_metadata = build_global_metadata( + gathered_metadatas, recalc_metadata=recalc_metadata + ) + if recalc_metadata: + # for recalc use cases, we only support rw for now, limit the blast radius + # will modify here once we support more sharding type + assert ( + len(local_shards) > 0 + and len(global_sharded_tensor_metadata.shards_metadata) > current_rank + ), ( + f"# for metadata recalculation, local_shards must be larger than 0 " + f"actual:{len(local_shards)}, # glb metadata must be greater than any rank id, " + f"# metadata:{len(global_sharded_tensor_metadata.shards_metadata)}, rank id:{current_rank}" + ) + local_md = [ + shard_md + for shard_md in global_sharded_tensor_metadata.shards_metadata + if shard_md.placement.rank() == current_rank + ] + assert len(local_md) == 1, ( + f"should has and only has one metadata for local rank, actual:{local_md}" + ) + local_shards[0].metadata = local_md[0] + tensor_properties = global_sharded_tensor_metadata.tensor_properties + + # STEP 3: Validation done, create the actual ShardedTensor and populate fields + # prepare initialization + spec = shard_spec._infer_sharding_spec_from_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # attach local_shards to the ShardedTensor created + sharded_tensor._local_shards = local_shards + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def _init_from_local_tensor( + cls, + local_tensor: torch.Tensor, + sharding_spec: shard_spec.ShardingSpec, + *global_size: Sequence[int], + process_group: dist.ProcessGroup | None = None, + init_rrefs=False, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor given only one local tensor, global sharded tensor + size and sharding spec on each rank. + + Args: + local_tensor (Tensor): Single tensor of local shard stored in each rank. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how to shard the Tensor. + global_size (Sequence[int]): Size of the sharded tensor. + process_group (ProcessGroup, optional): The process group to aggregate on. + Default: None + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` sharded based on the given sharding_spec with local + tensor stored in the current rank. + + Examples: + >>> # xdoctest: +SKIP + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2])) + >>> local_tensor + tensor([[1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6]]) # Rank 1 + >>> sharding_dim = 0 + >>> sharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + ], + ) + >>> st = ShardedTensor._init_from_local_tensor( + ... local_tensor, sharding_spec, [2, 4] + ... ) + >>> st + ShardedTensor( + ShardedTensorMetadata( + shards_metadata=[ + ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0), + ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1), + ], + size=torch.Size([2, 4]) + ) + >>> st.local_tensor() + tensor([1, 2, 3, 4]) # Rank 0 + tensor([3, 4, 5, 6]) # Rank 1 + + Warning: This API is experimental and subject to change. It lacks of a fully across + rank validations, and we only validate the local shard on the current rank. + We fully rely on the user to ensure local tensor is sharded based on the + sharding spec. + """ + if not local_tensor.is_contiguous(): + raise ValueError("local_tensor is not a contiguous Tensor.") + + global_tensor_size = _flatten_tensor_size(global_size) + tensor_properties = TensorProperties( + dtype=local_tensor.dtype, + layout=local_tensor.layout, + requires_grad=local_tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=local_tensor.is_pinned(), + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + global_tensor_size, tensor_properties + ) + + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + local_shards: list[Shard] = [] + for shard_metadata in sharded_tensor_metadata.shards_metadata: + rank, _device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + if rank == current_rank: + local_shards.append(Shard(local_tensor, shard_metadata)) + + # TODO: figure out what the API should behave when some rank have no shard + # see https://github.com/pytorch/pytorch/issues/7313 + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, + sharded_tensor_metadata, + process_group=process_group, + init_rrefs=init_rrefs, + sharding_spec=sharding_spec, + ) + + @classmethod + def _init_from_local_shards_and_global_metadata( # type: ignore[override] + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + process_group=None, + init_rrefs=False, + sharding_spec=None, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor with local shards and a global + ShardedTensorMetadata built on each rank. + + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + shards_metadata = sharded_tensor_metadata.shards_metadata + + local_shard_metadatas = [] + + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + + if current_rank == rank: + local_shard_metadatas.append(shard_metadata) + + if len(local_shards) != len(local_shard_metadatas): + raise RuntimeError( + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " + ) + + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): + tensor_property_or_metadata = ( + "tensor property" if is_property else "local ShardMetadata" + ) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property is incompatible with " + f"{tensor_property_or_metadata} on rank {rank}: " + f"{tensor_property_or_metadata} {prop_name}={expected}, " + f"local shard tensor {prop_name}={actual}." + ) + + for shard in local_shards: + shard_meta = shard.metadata + local_shard_tensor = shard.tensor + placement = shard_meta.placement + assert placement is not None, "Must specify placement for `Shard`!" + rank = placement.rank() + local_device = placement.device() + + _raise_if_mismatch( + tensor_properties.layout, + local_shard_tensor.layout, + "layout", + rank, + True, + ) + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + _raise_if_mismatch( + shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + rank, + ) + _raise_if_mismatch( + tensor_properties.pin_memory, + local_shard_tensor.is_pinned(), + "pin_memory", + rank, + True, + ) + _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank) + _raise_if_mismatch( + tensor_properties.dtype, + local_shard_tensor.dtype, + "dtype", + rank, + True, + ) + _raise_if_mismatch( + tensor_properties.requires_grad, + local_shard_tensor.requires_grad, + "requires_grad", + rank, + True, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor._local_shards = local_shards + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + def sharding_spec(self) -> shard_spec.ShardingSpec: + """ + Returns the ShardingSpec for the tensor. + """ + return self._sharding_spec + + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: + """ + Reshard a sharded tensor given the ``resharding_spec``. For now, we only support + single local shard. + + If ``resharding_spec`` is same as the original one, this becomes a no-op. + If only ``resharding_spec`` shares the same sharding dim with the original one, + we swap local shards directly. + For more generic cases, we merge different shards across different ranks and split + the local shards based on the ``resharding_spec`` via `all_to_all` collective API. + + Args: + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + + Returns: + A :class:`ShardedTensor` object whose local shards are resharded. + + Examples: + >>> # xdoctest: +SKIP + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank + >>> tensor = torch.stack([tensor, tensor]) + >>> tensor + tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1 + tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2 + tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3 + >>> sharding_dim = 0 + >>> spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> current_offsets = [0] * 2 + >>> current_offsets[0] = rank * 2 + >>> shard_metadata = ShardMetadata( + shard_offsets=copy.deepcopy(current_offsets), + shard_sizes=tensor.size(), + placement=spec.placements[rank], + ) + >>> local_shards = [ + Shard( + tensor=tensor, + metadata=shard_metadata, + ) + ] + >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size()) + >>> sharding_dim = 1 + >>> resharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> st.reshard(resharding_spec) + >>> tensor = st.local_shards()[0].tensor + >>> tensor + tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0 + tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1 + tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 + tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 + """ + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): + raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") + + num_local_shards = len(self.local_shards()) + if num_local_shards != 1: + raise NotImplementedError( + f"Only single local shard supported for reshard. Number of shards: {num_local_shards}" + ) + + if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] + if self._sharding_spec.placements == resharding_spec.placements: # type: ignore[attr-defined] + return self + else: + local_shards, shards_metadata = reshuffle_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + else: + local_shards, shards_metadata = reshard_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + self._local_shards = local_shards + self._metadata.shards_metadata = shards_metadata + self._sharding_spec = resharding_spec + return self + + def local_tensor(self) -> torch.Tensor: + """ + Return local tensor for a sharded_tensor. For now we only support single local shard. + + Returns: + A :class:`torch.Tensor` of the local shard. + """ + num_local_shards = len(self.local_shards()) + if num_local_shards != 1: + raise NotImplementedError( + f"Only single local shard is supported. Number of shards: {num_local_shards}" + ) + return self.local_shards()[0].tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def __torch_function__(cls, func, types, args=(), kwargs=None): + def dispatch(st: ShardedTensor, func: Callable): + # Dispatch to custom user provided op first if it exists. + if func in _CUSTOM_SHARDED_OPS: + return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group) + + # Dispatch to custom sharding spec op if it has one. + if _has_custom_op(st._sharding_spec, func): + return _dispatch_custom_op( + st._sharding_spec, func, types, args, kwargs, st._process_group + ) + + if func in _SHARDED_OPS: + return _SHARDED_OPS[func](types, args, kwargs, st._process_group) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + # Find ShardedTensor instance to get process_group and sharding_spec. + st_instance = None + + def find_sharded_tensor(e): + nonlocal st_instance + if st_instance is None and isinstance(e, ShardedTensor): + st_instance = e + + pytree.tree_map_(find_sharded_tensor, args) + pytree.tree_map_(find_sharded_tensor, kwargs) + + if st_instance is not None: + return dispatch(st_instance, func) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + def is_pinned(self) -> bool: # type: ignore[override] + """ + Returns True if the sharded tensor (each local shard) resides in pinned memory. + """ + return self._metadata.tensor_properties.pin_memory + + def _register_remote_shards( + self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int + ): + self._remote_shards[rpc_rank] = remote_shards + + def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]: + """ + Returns a Dict[int, RRef] with keys being the RPC rank and values + being RRefs to shards on that rank. Need to initialize the + RPC framework for this functionality. + + Raises an exception if ShardedTensor was created with ``init_rrefs=False`` + """ + if not self._init_rrefs: + raise RuntimeError( + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" + ) + return self._remote_shards + + def __hash__(self): + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"ShardedTensor({self._metadata})" + + @dataclass + class ProcessGroupState: + """ + State for ser-de of process group + """ + + local_rank: int + global_rank: int + local_world_size: int + global_world_size: int + + def __getstate__(self): + pg_state = ShardedTensor.ProcessGroupState( + distributed_c10d.get_rank(self._process_group), + distributed_c10d.get_rank(), + distributed_c10d.get_world_size(self._process_group), + distributed_c10d.get_world_size(), + ) + + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) + + def __setstate__(self, state): + self._sharded_tensor_id = None + if not distributed_c10d.is_initialized(): + raise RuntimeError( + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) + + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state + + # Setup process group + from torch.distributed._shard.api import _get_current_process_group + + self._process_group = _get_current_process_group() + + # Validate process group. + local_rank = distributed_c10d.get_rank(self._process_group) + if pg_state.local_rank != local_rank: + raise RuntimeError( + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) + + global_rank = distributed_c10d.get_rank() + if pg_state.global_rank != global_rank: + raise RuntimeError( + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) + + local_world_size = distributed_c10d.get_world_size(self._process_group) + if pg_state.local_world_size != local_world_size: + raise RuntimeError( + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) + + global_world_size = distributed_c10d.get_world_size() + if pg_state.global_world_size != global_world_size: + raise RuntimeError( + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) + + self._post_init() + + +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" + dtype = tensor_properties.dtype + layout = tensor_properties.layout + requires_grad = tensor_properties.requires_grad + memory_format = tensor_properties.memory_format + pin_memory = tensor_properties.pin_memory + + return torch.empty( + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8cb4d18fb180ea620dd8daad60b5771a9688be --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logger.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + + +__all__: list[str] = [] + + +def _get_or_create_logger() -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler() + logger = logging.getLogger(f"sharding-spec-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = "default", +) -> tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = type(log_handler).__name__ + return (log_handler, log_handler_name) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6832fd1ae834b6365a6b005b07bbbfffe90726 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +__all__: list[str] = [] + +_log_handlers: dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..466ca1a0c519ce4cc4ee24fae98ff4ddfbee300a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass, field +from enum import Enum + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default=torch.get_default_dtype()) + layout: torch.layout = field(default=torch.strided) + requires_grad: bool = False + memory_format: torch.memory_format = field(default=torch.contiguous_format) + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class ShardedTensorMetadata: + """ + Represents metadata for :class:`ShardedTensor` + """ + + # Metadata about each shard of the Tensor + shards_metadata: list[ShardMetadata] = field(default_factory=list) + + # Size of each dim of the overall Tensor. + size: torch.Size = field(default=torch.Size([])) + + tensor_properties: TensorProperties = field(default_factory=TensorProperties) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py new file mode 100644 index 0000000000000000000000000000000000000000..daef9c3586184e4e62b4a141ec2e43f5025bf454 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py @@ -0,0 +1,243 @@ +# mypy: allow-untyped-defs +import copy + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from torch.distributed.nn.functional import all_to_all, all_to_all_single + +from .shard import Shard + + +def get_idx_from_placements(placements, current_rank) -> int: + """ + Return the position of the current rank in the given placements. + + Args: + placements(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + current_rank (int): number of current device. + + Returns: + A int which contains the position of current device in the placement list. + """ + for idx, placement in enumerate(placements): # type: ignore[attr-defined] + if current_rank == placement.rank(): # type: ignore[union-attr] + return idx + raise RuntimeError("current_rank not in the placement.") + + +def build_reshard_metadata( + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + world_size: int, +) -> tuple[list[ShardMetadata], list[int]]: + """ + Based the given sharding spec, we calculate the offset and local shard size. + We then build a ShardMetadata on top of the calculation result. + + Args: + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + world_size (int): number of ranks. + + Returns: + A Tuple of the followings: + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + A List[int] which contains the ranks in the order of placement. + """ + shard_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + shards_metadata = [None] * world_size + ranks = [] + offsets = [0] * len(st_size) + split_size = get_split_size(st_size[shard_dim], world_size) + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + ranks.append(placement.rank()) + sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx) + local_tensor_size = list(st_size) + local_tensor_size[shard_dim] = sharded_dim_size + shards_metadata[placement.rank()] = ShardMetadata( # type: ignore[call-overload] + shard_offsets=copy.deepcopy(offsets), + shard_sizes=local_tensor_size, + placement=placement, + ) + offsets[shard_dim] += sharded_dim_size + return shards_metadata, ranks # type: ignore[return-value] + + +def reshuffle_local_shard( + local_shard: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshuffle the local shard directly when the reshard dim is same as the original + sharding dim. Logically we do this in two step: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending the local tensor to + the new shard directly based on the resharding spec. + + Args: + local_shard (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + # Get input split size for all2all. + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + split_size = get_split_size(st_size[reshard_dim], world_size) + input_split_sizes = [0] * world_size + idx = get_idx_from_placements(sharding_spec.placements, current_rank) # type: ignore[attr-defined] + new_rank = resharding_spec.placements[idx].rank() # type: ignore[union-attr, attr-defined] + input_split_sizes[new_rank] = local_shard.size(reshard_dim) + # Get output split size for all2all. + output_split_sizes = [0] * world_size + new_idx = ranks.index(current_rank) + sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx) + output_split_sizes[new_rank] = sharded_dim_size + # Get gathered_input for all2all. + local_shard = local_shard.transpose(0, reshard_dim).contiguous() + gathered_input_size = list(local_shard.size()) + gathered_input_size[0] = sharded_dim_size + gathered_input = torch.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) + # all2all. + local_shard = all_to_all_single( + gathered_input, + local_shard, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + group=pg, + ) + local_tensor = local_shard.transpose(0, reshard_dim).contiguous() + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata + + +def reshard_local_shard( + local_tensor: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is + different from the original sharding dim, we need to do two steps logically: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending each rank the new + shard based on the resharding spec. + + Args: + local_tensor (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + current_sharding_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + + # Compute expected size + input_split_sizes = [ + metadata.shard_sizes[reshard_dim] for metadata in shards_metadata + ] + rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1)) + + if rearrange_input: + # Need to re-arrange reshard_dim of local_tensor before all2all. + indices: list[int] = [] + for metadata in shards_metadata: + offset_start_idx = metadata.shard_offsets[reshard_dim] + split_size = metadata.shard_sizes[reshard_dim] + indices += range(offset_start_idx, offset_start_idx + split_size) + local_tensor = local_tensor.index_select( + reshard_dim, torch.tensor(indices, device=local_tensor.device) + ) + + # Because reshard_dim != original shard_dim. We need to compute the + # size of tensor from each rank. + output_tensor_list = [torch.tensor(1)] * world_size + split_size = get_split_size(st_size[current_sharding_dim], world_size) + rearrange_output_list = False + indices = [] + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + sharded_dim_size = get_chunked_dim_size( + st_size[current_sharding_dim], split_size, idx + ) + output_tensor_size = list(st_size) + output_tensor_size[current_sharding_dim] = sharded_dim_size + output_tensor_size[reshard_dim] = input_split_sizes[current_rank] + output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index] + output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype + ) + indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] + if idx != placement.rank(): # type: ignore[union-attr] + rearrange_output_list = True + + # Perform autograd enabled all2all. + input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim) + input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple] + output_tensor_list = all_to_all( + output_tensor_list, + input_tensor_list, + group=pg, + ) + + if rearrange_output_list: + # Need to re-arrange original shard_dim of output_tensor_list. + output_tensor_list = [output_tensor_list[idx] for idx in indices] # type: ignore[call-overload] + local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim) + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9d4357436a6c15f590a4db486d9d54b6d6ca57 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/shard.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +import torch +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.remote_device import _remote_device + + +@dataclass +class Shard: + """ + Container which holds the data for a shard as a Tensor and also + the associated metadata for that shard. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): + The metadata for the shard, including offsets, lengths and device placement. + """ + + __slots__ = ["tensor", "metadata"] + tensor: torch.Tensor + metadata: ShardMetadata + + def __post_init__(self) -> None: + # verification between local tensor and metadata + if list(self.tensor.size()) != self.metadata.shard_sizes: + raise ValueError( + "Shard tensor size does not match with metadata.shard_lengths! " + f"Found shard tensor size: {list(self.tensor.size())}, " + f"metadata.shard_lengths: {self.metadata.shard_sizes}, " + ) + placement_device = self.metadata.placement + if ( + placement_device is not None + and placement_device.device() != self.tensor.device + ): + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {self.tensor.device}, " + f"local shard metadata placement device: {placement_device.device()}" + ) + + @classmethod + def from_tensor_and_offsets( + cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int + ) -> "Shard": + """ + Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + shard_offsets(List[int]): List of integers specify the offset + of the shard on each dimension. + rank(int): Specify the rank for the shard. + """ + shard_sizes = list(tensor.size()) + placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") + shard_meta = ShardMetadata( + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement + ) + return Shard(tensor, shard_meta) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b323da4ecbfa3adcea51367dc42a6e54d2cd1624 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharded_tensor/utils.py @@ -0,0 +1,325 @@ +# mypy: allow-untyped-defs +import collections.abc +import copy +import itertools +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from torch.distributed import distributed_c10d as c10d, rpc +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) + +from .metadata import ShardedTensorMetadata, TensorProperties +from .shard import Shard + + +if TYPE_CHECKING: + from torch.distributed._shard.metadata import ShardMetadata + + +def _parse_and_validate_remote_device(pg, remote_device): + if remote_device is None: + raise ValueError("remote device is None") + + worker_name = remote_device.worker_name() + rank = remote_device.rank() + device = remote_device.device() + + # Validate rank, skip validation if rank is not part of process group. + if rank is not None and not c10d._rank_not_in_group(pg): + pg_global_ranks = c10d.get_process_group_ranks(pg) + if rank not in pg_global_ranks: + raise ValueError( + f"Global rank {rank} does not exist in input process group: {pg_global_ranks}" + ) + + if worker_name is not None: + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + f"RPC framework needs to be initialized for using worker names: {worker_name}" + ) + + workers = rpc._get_current_rpc_agent().get_worker_infos() + for worker in workers: + if worker.name == worker_name: + return worker.id, device + + raise ValueError(f"Invalid worker name: {worker_name}") + + return rank, device + + +def _validate_output_tensor_for_gather( + my_rank: int, + dst_rank: int, + size: torch.Size, + dst_tensor: torch.Tensor | None, +) -> None: + if dst_rank == my_rank: + if dst_tensor is None: + raise ValueError( + f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}" + ) + if tuple(size) != (dst_tensor.size()): + raise ValueError( + f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())}," + f"but should be {tuple(size)}" + ) + elif dst_tensor: + raise ValueError( + "Argument ``dst_tensor`` must NOT be specified on non-destination ranks." + ) + + +def _flatten_tensor_size(size) -> torch.Size: + """ + Checks if tensor size is valid, then flatten/return a torch.Size object. + """ + if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): + # pyrefly: ignore [not-iterable] + dims = list(*size) + else: + dims = list(size) + + for dim in dims: + if not isinstance(dim, int): + raise TypeError(f"size has to be a sequence of ints, found: {dims}") + + return torch.Size(dims) + + +def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): + if is_local: + assert isinstance(ranks, int) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) + else: + # compare failure check across ranks, ranks list should have two rank + assert len(ranks) == 2 + if expected != actual: + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) + + +def build_metadata_from_local_shards( + local_shards: list[Shard], + global_size: torch.Size, + current_rank: int, + pg: c10d.ProcessGroup, +) -> ShardedTensorMetadata: + assert len(local_shards) > 0, "must have local shards!" + local_shard_metadatas: list[ShardMetadata] = [] + + first_shard_dtype = local_shards[0].tensor.dtype + first_shard_layout = local_shards[0].tensor.layout + first_shard_requires_grad = local_shards[0].tensor.requires_grad + first_shard_is_pinned = local_shards[0].tensor.is_pinned() + + # 1). Validate local tensors and associated metadatas + for local_shard in local_shards: + local_shard_tensor = local_shard.tensor + local_shard_meta = local_shard.metadata + local_shard_metadatas.append(local_shard_meta) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) + + if ( + local_shard_tensor.layout != torch.strided + or local_shard_tensor.layout != first_shard_layout + ): + raise ValueError( + f"Only torch.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" + ) + + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported!" + ) + + if rank != current_rank: + raise ValueError( + f"Local shard metadata's rank does not match with the rank in its process group! " + f"Found current rank in the process group: {current_rank}, " + f"local ShardMetadata placement's rank: {rank}" + ) + if local_shard_tensor.device != local_device: + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {local_shard_tensor.device}, " + f"local shard metadata placement device: {local_device}" + ) + + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) + + # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then + # do all_gather to collect local_sharded_tensor_metadata from all ranks + local_tensor_properties = TensorProperties( + dtype=first_shard_dtype, + layout=first_shard_layout, + requires_grad=first_shard_requires_grad, + memory_format=torch.contiguous_format, + pin_memory=first_shard_is_pinned, + ) + + local_sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=local_shard_metadatas, + size=global_size, + tensor_properties=local_tensor_properties, + ) + + return local_sharded_tensor_metadata + + +def build_global_metadata( + gathered_metadatas: Sequence[ShardedTensorMetadata | None], + recalc_metadata: bool = False, +): + global_sharded_tensor_metadata = None + global_metadata_rank = 0 + + # pyrefly: ignore [bad-assignment] + for rank, rank_metadata in enumerate(gathered_metadatas): + if rank_metadata is None: + continue + + if global_sharded_tensor_metadata is None: + global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) + global_metadata_rank = rank + else: + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) + + # don't need to check layout and memory format as we already checked in local shards validation stage + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) + # pass all validations, extend shards metadata + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) + + if global_sharded_tensor_metadata is not None: + if recalc_metadata: + recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata, + 0, # sharded on 0th dim + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + + # check if the shards_metadata is compatible with global size of the sharded tensor. + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) + else: + raise ValueError("ShardedTensor have no local shards on all ranks!") + + return global_sharded_tensor_metadata + + +def recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata: ShardedTensorMetadata, sharded_dim: int +) -> None: + # recalculate global ShardedTensorMetadata + + # reorder here in case shard metadata is not sorted on sharded_dim + placement_idx_pairs = [] + for i, shard_metadata in enumerate(global_sharded_tensor_metadata.shards_metadata): + if shard_metadata.placement: + placement_idx_pairs.append((shard_metadata.placement.rank(), i)) + else: + raise AssertionError( + "currently only support rw, it should always have valid rank info" + ) + sorted_idx = sorted(placement_idx_pairs) + shard_sizes = [ + global_sharded_tensor_metadata.shards_metadata[idx].shard_sizes[sharded_dim] + for _, idx in sorted_idx + ] + cum_sum = [0] + list(itertools.accumulate(shard_sizes)) + + for shard_id, shard_metadata in enumerate( + global_sharded_tensor_metadata.shards_metadata + ): + # update shard offset for each shard on the sharded dimension + shard_metadata.shard_offsets[sharded_dim] = cum_sum[shard_id] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + # shard offset for each shard on the unsharded dimension + shard_metadata.shard_offsets[other_dim] = 0 + + # update global size for ShardedTensorMetadata + global_size_list = [] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + global_size_list.append( + global_sharded_tensor_metadata.shards_metadata[0].shard_sizes[other_dim] + ) + else: + global_size_list.append(cum_sum[-1]) + global_sharded_tensor_metadata.size = torch.Size(global_size_list) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d91ec15775bea870b81c4b10fb1443a3fba0977 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharder.py @@ -0,0 +1,29 @@ +import abc + +import torch.nn as nn + + +class Sharder(abc.ABC): + """ + This is an interface which allows user to create more advanced + sharding strategies that are not easily be composed by the + `ShardingSpec`. + + :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could + take an object of the `Sharder` and call `shard` to shard the module, + then replace the original module with sharded module returned. + """ + + @abc.abstractmethod + def shard(self, module: nn.Module) -> nn.Module: + """ + Shard a module base on the implementation of this method, and + return the sharded version of the module. + + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.nn.Module` object that represents a module + that's already been sharded. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..325f7d7eb47b96a79fdc10cc2d1f072cdec9b4ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/__init__.py @@ -0,0 +1 @@ +from .api import ShardingPlan, ShardingPlanner diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a94f4b54edf2b6c29fd9331ec5e662a793510102 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_plan/api.py @@ -0,0 +1,86 @@ +import abc +from dataclasses import dataclass + +import torch.nn as nn +from torch.distributed._shard.sharder import Sharder +from torch.distributed._shard.sharding_spec import ShardingSpec + + +@dataclass +class ShardingPlan: + """ + Representation of a sharding plan, describes how to shard a module + across hosts. `plan` is used to shard module parameters according to the spec provided, + `output_plan` and `return_local_tensor` are optional, they are used to specify the output + layout of a module with a spec, and when to convert back to data parallel fashion. + + Args: + plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`, + :class:`torch.distributed._shard.sharder.Sharder`]): + a dict describes how to shard a module, there're currently two ways to shard a module: + 1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of + a parameter to a `ShardingSpec`. + 2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module + to a `Sharder` object. + output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`), optional): + a dict specifies the layout of a module's output which produces a ShardedTensor, + keyed by the name of module to ShardingSpec("" in key means the root module). + Default: `None` + return_local_tensor (List[str], optional): a list of string, each element enables + a module's sharded output to be returned as a Tensor from its local shards to + ensure further processing in a data parallel fashion. ("" in list means the + root module). + Default: None + Example: + Suppose we want to shard a module with two linear layers and then run it with DDP, we also + want to convert the output of the second linear layer back to DDP, we can do it as follows: + + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> class MyModule(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.fc1 = nn.Linear() + >>> self.gelu = nn.GELU() + >>> self.fc2 = nn.Linear() + >>> self.relu = nn.Linear() + >>> + >>> def forward(self, input): + >>> return self.relu(self.fc2(self.gelu(self.fc1(input)))) + + + >>> # xdoctest: +SKIP("Undefined spec1, spec2) + >>> sharding_plan = ShardingPlan( + >>> plan={ + >>> "fc1.weight": spec1, + >>> "fc2.weight": spec2 + >>> }, + >>> output_plan={ + >>> "fc2": output_spec + >>> }, + >>> return_local_tensor=["fc2"] + >>> ) + """ + + plan: dict[str, ShardingSpec | Sharder] + output_plan: dict[str, ShardingSpec] | None = None + return_local_tensor: list[str] | None = None + + +class ShardingPlanner(abc.ABC): + """ + Default ShardingPlanner interface, can be extended and + implement advanced sharding strategies. + """ + + @abc.abstractmethod + def build_plan(self, module: nn.Module) -> ShardingPlan: + """ + Given a nn.Module, define how to shard the module across + ranks, return a ShardingPlan + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that + represents how to shard the module. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd3f0a7581e8c4352eba843af6d3751bee7f387 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/__init__.py @@ -0,0 +1,10 @@ +from torch.distributed._shard.metadata import ShardMetadata + +from .api import ( + _infer_sharding_spec_from_shards_metadata, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardingSpec, +) +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..486c62a18cd7b91e30ad21891fb0c735e28d443f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/_internals.py @@ -0,0 +1,244 @@ +# mypy: allow-untyped-defs +import math +import sys +from bisect import bisect_right, insort + +from torch.distributed._shard.metadata import ShardMetadata + + +def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): + """ + Checks if two shards overlap. + """ + + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.shard_offsets) + for i in range(ndims): + if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]: + return False + if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]: + return False + + return True + + +def _find_nd_overlapping_shards( + shards: list[ShardMetadata], sharded_dims: list[int] +) -> tuple[int, int] | None: + """Find overlapping shards using sweep-line algorithm.""" + if len(shards) <= 1: + return None + + dims = len(sharded_dims) + if dims == 0: + return None + + sweep_dim_idx = 0 + if dims > 1: + max_size = 0 + for i, dim in enumerate(sharded_dims): + dim_size = shards[0].shard_offsets[dim] + shards[0].shard_sizes[dim] + if dim_size > max_size: + max_size = dim_size + sweep_dim_idx = i + sweep_dim = sharded_dims[sweep_dim_idx] + + sorted_indices = sorted( + range(len(shards)), + key=lambda idx: ( + shards[idx].shard_offsets[sweep_dim], + *(shards[idx].shard_offsets[d] for d in sharded_dims if d != sweep_dim), + ), + ) + active: list[tuple[int, int]] = [] + + for idx in sorted_indices: + current = shards[idx] + start = current.shard_offsets[sweep_dim] + end = start + current.shard_sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = shards[other_idx] + + if _check_shard_metadata_pair_overlap(current, other): + return (other_idx, idx) + insort(active, (end, idx)) + return None + + +def _find_1d_overlapping_shards( + shards: list[ShardMetadata], dim: int +) -> tuple[int, int] | None: + # (begin, end, index_in_shards). Begin and end are inclusive. + intervals = [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i) + for i, s in enumerate(shards) + ] + intervals.sort() + for i in range(len(shards) - 1): + if intervals[i][1] >= intervals[i + 1][0]: + return (intervals[i][2], intervals[i + 1][2]) + return None + + +def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): + """ + Ensures none of the shards overlap with each other. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. + Raises: + ``ValueError`` if there's overlap in any two shards. + """ + if not shards or len(shards) == 1: + return + + sharded_dims: list[int] = [] + for dim in range(len(shards[0].shard_offsets)): + for i in range(1, len(shards)): + if ( + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + ): + sharded_dims.append(dim) + break + + pair: tuple[int, int] | None = None + if len(sharded_dims) == 0: + # if shard is all zeros, we should consider as pass + all_zeros: bool = all( + # strictly limited all offsets to be 0 to pass + # could loose it later on + shard.shard_offsets == [0] * len(shards[0].shard_offsets) + and math.prod(shard.shard_sizes) == 0 # one dimension is 0 + for shard in shards + ) + if all_zeros: + return + # All shards are the same, all dims are not partitioned. Choose any 2. + pair = (0, 1) + elif len(sharded_dims) == 1: + # Shards are partitioned over only one dimension. Overlap can be found + # using a O(nlogn) overlapping interval algorithm. + pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) + else: + # Shards are partitioned over more than one dimension. + # Use sweep-line algorithm for O(n log n) complexity. + pair = _find_nd_overlapping_shards(shards, sharded_dims) + + if pair: + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") + + +def check_tensor(shards_metadata, tensor_dims) -> None: + """ + Checks if the shards_metadata is compatible with the provided tensor dims. + + Args: + shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata` + objects representing each shard of the tensor. + tensor_dims(Sequence of int): Dimensions of tensor to verify + Raises: + ``ValueError`` if not compatible. + """ + + # If the tensor's volume matches the total volume of all shards and + # all shard boundaries are within tensor dims, we have a compatible + # sharding spec for this tensor. Note that we have already verified + # we don't have overlapping shards. + tensor_rank = len(tensor_dims) + shards_rank = len(shards_metadata[0].shard_offsets) + if tensor_rank != shards_rank: + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) + + total_shard_volume = 0 + for shard in shards_metadata: + shard_volume = 1 + for i, shard_length in enumerate(shard.shard_sizes): + shard_volume *= shard_length + if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: + raise ValueError( + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) + total_shard_volume += shard_volume + + tensor_volume = 1 + for size in tensor_dims: + tensor_volume *= size + + if total_shard_volume != tensor_volume: + # TODO: Can we improve this error message to point out the gaps? + raise ValueError( + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + + +def get_split_size(dim_size, chunks): + """ + Computes the split size inline with ``torch.chunk`` + + Args: + dim_size(int): Size of the dimension being chunked. + chunks(int): Number of chunks to create for ``dim_size``. + + Returns: + An int indicating the split size to use. + """ + return (dim_size + chunks - 1) // chunks + + +def get_chunked_dim_size(dim_size, split_size, idx): + """ + Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` + and ``split_size``. + + Args: + dim_size(int): Size of the dimension being chunked. + split_size(int): The chunk size for each chunk of ``dim_size``. + idx(int): The index of chunk whose dim size is being requested. + + Returns: + An int indicating the dim size of the chunk. + """ + return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + + +def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): + """ + Generate the start pos and offset length for the current rank for + chunk sharding. + + Args: + sharding_dim_size(int): The dimension length which we shard on. + world_size(int): number of ranks. + spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`): + sharding spec. + rank(int): # of cuda process. + + Returns: + start_pos(int): start position of sharded tensor on the given rank. + chunk_size(int): chunk size of sharded tensor on the given rank. + """ + split_size = get_split_size(sharding_dim_size, world_size) + current_offsets = 0 + start_pos = current_offsets + for idx, placement in enumerate(spec.placements): + chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + if rank == placement.rank(): + start_pos = current_offsets + break + current_offsets += chunk_size + return start_pos, chunk_size # type: ignore[possibly-undefined] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py new file mode 100644 index 0000000000000000000000000000000000000000..87a49abdb5c05dcfe3db1fdf734dc9f3bef3b4bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/api.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +import functools +import operator +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.op_registry_utils import _decorator_func + +from ._internals import ( + check_tensor, + get_chunked_dim_size, + get_split_size, + validate_non_overlapping_shards_metadata, +) + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class PlacementSpec(ABC): # noqa: B024 + """ + Base class representing the placement of an entity. Subclasses of this + class can be used to specify customized placements which might not be + covered by existing APIs. + """ + + +@dataclass +class DevicePlacementSpec(PlacementSpec): + """ + Associates placement of an entity with a single device. + + Args: + device(:class:`torch.distributed._remote_device`): The device to place the entity on. + """ + + device: torch.distributed._remote_device + + def __post_init__(self): + if not isinstance(self.device, torch.distributed._remote_device): + self.device = torch.distributed._remote_device(self.device) + + +class ShardingSpec(ABC): + """ + Base class representing sharding specifications. + """ + + @abstractmethod + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + """ + Given a global tensor size, define how to shard a tensor like this shape + across ranks, return ShardedTensorMetadata + Args: + tensor_sizes (:class:`torch.Size`): + The tensor shape to shard on, a `torch.Size` object that represents the + tensor shape to be sharded according to the ShardingSpec. + tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties): + Tensor properties used to create a ShardedTensor. + Returns: + A :class:`ShardedTensorMetadata` object that encodes the information about + the layout of the ShardedTensor and its properties. + """ + + @abstractmethod + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Given a global tensor on src_rank, shard this tensor + across ranks within the process group, return a ShardedTensor. + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + """ + + +# Ops customized for a particular ShardingSpec. +_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {} + + +def _has_custom_op(sharding_spec, op): + """ + Returns whether or not the ShardingSpec has a custom op implementation. + """ + class_name = type(sharding_spec).__qualname__ + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): + """ + Calls the custom op for this ShardingSpec if it exists. + """ + class_name = type(sharding_spec).__qualname__ + if not _has_custom_op(sharding_spec, op): + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") + func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] + return func(types, args, kwargs, process_group) + + +def custom_sharding_spec_op(sharding_spec_class, func): + """ + Decorator to allow custom registration of ops. + Args: + sharding_spec_class(type): The ShardingSpec for which we need to add this custom op. + func(Callable): The op to override (ex: torch.bmm) + """ + class_name = sharding_spec_class.__qualname__ + if class_name not in _CUSTOM_SHARDING_SPEC_OPS: + _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} + return functools.partial( + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +@dataclass +class EnumerableShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that allows users to specify a generic + sharding scheme by enumerating exactly how each shard is laid out. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. Note that none of the shards should overlap. + """ + + shards: list[ShardMetadata] + + def __post_init__(self): + if len(self.shards) == 0: + raise ValueError(f"Empty shard list provided: {self.shards}") + + # Validate each shard has same rank. + rank = -1 + for shard in self.shards: + if rank != -1 and rank != len(shard.shard_offsets): + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) + rank = len(shard.shard_offsets) + + validate_non_overlapping_shards_metadata(self.shards) + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + # check if shards form a valid tensor + check_tensor(self.shards, tensor_sizes) + return sharded_tensor_meta.ShardedTensorMetadata( + self.shards, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec + raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") + + +def _infer_sharding_spec_from_shards_metadata(shards_metadata): + """ + Infer the sharding spec from the metadata of each shard of a ShardedTensor. + If the tensor is sharded only on one dimension, we can then verify whether it's + a ChunkShardingSpec or not. The way to verify it is to first get the total length + and perform a chunk sharding with the given placements to see if we can have the + same chunk size as the given shards_metadata. If not, we assume it's enum sharded. + + Args: + shards_metadata (List[ShardMetadata]): List of Metadata of local shards. + + Returns: + A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding + spec for one sharded tensor. + """ + placements = [] + chunk_sharding_dim = None + chunk_offset_list = [] + shard_size_list = [] + shard_offset_list = [] + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + placements.append(shard_metadata.placement) + local_offsets = shard_metadata.shard_offsets + chunk_offset_list.append(sum(local_offsets)) + shard_size_list.append(shard_metadata.shard_sizes) + shard_offset_list.append(shard_metadata.shard_offsets) + shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0] + # If the offset is [0, 0, ..., 0] (all zeros), + # we cannot decide whether how the tensor is sharded. + if len(shard_dims) == 0: + continue + # If the offset is [0, N, .,0, M, 0, .., 0], + # we are sure it's sharded by more than one dimension. + if len(shard_dims) != 1: + chunk_sharding_dim = None + break + # If the offset is [0, 0, .,0, M, 0, .., 0], aka, it's sharded by just + # one dimension, we need to make sure all ranks share the same dimension. + if not chunk_sharding_dim: + chunk_sharding_dim = shard_dims[0] + elif chunk_sharding_dim != shard_dims[0]: + chunk_sharding_dim = None + break + + if chunk_sharding_dim is not None: + # Ensure we infer the correct placement order from offsets + placements = [ + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) + ] + + from .chunk_sharding_spec import ChunkShardingSpec + + chunk_spec = ChunkShardingSpec( + dim=chunk_sharding_dim, + placements=placements, + ) + + shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list]) + shard_total_length = sum(shard_sizes) + shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list]) + + chunks = len(placements) + split_size = get_split_size(shard_total_length, chunks) + chunk_shard_sizes = sorted( + [ + get_chunked_dim_size(shard_total_length, split_size, idx) + for idx in range(chunks) + ] + ) + # Should match ChunkShardingSpec offsets calculation + chunk_shard_offsets = [split_size * idx for idx in range(chunks)] + if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets: + return chunk_spec + return EnumerableShardingSpec(shards_metadata) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7b11b7c16c567b0b71fc6a0858dc58b7977ebf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import cast, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard._utils import narrow_tensor +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharded_tensor.utils import ( + _parse_and_validate_remote_device, +) + +from ._internals import get_chunked_dim_size, get_split_size +from .api import ShardingSpec + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +@dataclass +class ChunkShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that defines the placement as being sharded + across multiple devices. In particular, it represents sharding a Tensor + along a single dimension into equal chunks (similar to :meth:`torch.chunk`). + + The semantics of how a tensor is partitioned is inline with + :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the + specified ``dim`` and ``chunks`` in torch.chunk is the number of elements + in the placement specified. + + Args: + dim (int or str): + The dimension to shard on, could be an integer representing the + dimension or a string in case of named tensors where dimensions are + named. Note that named tensor support is not added yet. + placement(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + """ + + ShardingDim = int | str + + dim: ShardingDim + placements: list[torch.distributed._remote_device | str] + + def __post_init__(self): + self._verify_dim(self.dim) + for i, remote_device in enumerate(self.placements): + if not isinstance(remote_device, torch.distributed._remote_device): + self.placements[i] = torch.distributed._remote_device(remote_device) + + @staticmethod + def _verify_dim(dim): + # Validate the sharding spec. + # TODO: support named dimension + if isinstance(dim, str): + raise NotImplementedError( + "ChunkShardingSpec does not support named dimension yet!" + ) + + if not isinstance(dim, int): + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + tensor_num_dim = len(tensor_sizes) + + self._verify_dim(self.dim) + if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator] + raise ValueError(f"Invalid sharding dim: {self.dim}") + + shards_metadata = [] + sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + for idx, placement in enumerate(self.placements): + # generate ShardMetadata for each placement device + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = list(tensor_sizes) + current_offsets = [0] * tensor_num_dim + current_offsets[self.dim] = split_size * idx # type: ignore[index] + shard_size[self.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=current_offsets, + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) + + return sharded_tensor_meta.ShardedTensorMetadata( + shards_metadata, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Args: + src_rank: group rank relative to ``process_group`` + + N.B. If ``process_group`` is None, ``src_rank`` is a global rank. + """ + # relative imports to avoid circular dependency + from torch.distributed._shard.sharded_tensor import ShardedTensor + + tensor_properties = sharded_tensor_meta.TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + current_rank = dist.get_rank(process_group) + current_global_rank = dist.get_rank() + tensor_meta = self.build_metadata(tensor.size(), tensor_properties) + local_shards = [] + local_tensor = None + local_metadata = None + + tensors_to_scatter = cast( + list[torch.Tensor | None], + [None] * dist.get_world_size(process_group), + ) + + sharding_dim_size = tensor.size()[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + scatter_shape = list(tensor.size()) + scatter_shape[self.dim] = split_size # type: ignore[index] + + for shard_meta in tensor_meta.shards_metadata: + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) + if current_rank == src_rank: + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrow_tensor(tensor, shard_meta) + if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index] + # for the last shard that might be smaller to other shards + # resize the narrowed tensor to the same size and use it for + # the scatter collective as dist.scatter requires same size + # inputs on every rank + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) + else: + tensor_to_scatter = narrowed_tensor.detach().clone( + memory_format=torch.contiguous_format + ) + + tensors_to_scatter[ + # pyrefly: ignore [bad-argument-type] + dist.get_group_rank(process_group, remote_global_rank) + ] = tensor_to_scatter + + if current_global_rank == remote_global_rank: + local_tensor = torch.empty( + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) + local_metadata = shard_meta + + # each rank should have local_tensor and local_metadata initialized if we build + # the metadata list in a correct way. + assert local_tensor is not None + assert local_metadata is not None + + # Scatter the shards to all ranks in the pg + # scatter takes the global rank as ``src`` + src_for_scatter = src_rank + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) + + tensors_to_scatter_: list[torch.Tensor] | None = None + if current_rank == src_rank: + tensors_to_scatter_ = [] + for t in tensors_to_scatter: + assert isinstance(t, torch.Tensor) + tensors_to_scatter_.append(t) + + dist.scatter( + local_tensor, + scatter_list=tensors_to_scatter_, + src=src_for_scatter, + group=process_group, + ) + + if list(local_tensor.size()) != local_metadata.shard_sizes: + # detach again after receiving to ensure local shards remain a leaf node + local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach() + + # Sync requires_grad to local_shard. + local_tensor.requires_grad = tensor.requires_grad + + local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, tensor_meta, process_group=process_group + ) + + # Manually set sharding_spec + st._sharding_spec = self + + return st diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24de2628c0ab9ceb89fa28b52753a421b58b56c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharded_tensor/__init__.py @@ -0,0 +1,21 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharded_tensor import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharded_tensor` will be deprecated, " + "use `torch.distributed._shard.sharded_tensor` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._sharded_tensor"] = ( + torch.distributed._shard.sharded_tensor +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c74dd3633e0f5e8436b844fd2d14f3bdb00635b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__init__.py @@ -0,0 +1,22 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharding_spec import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharding_spec` will be deprecated, " + "use `torch.distributed._shard.sharding_spec` instead", + DeprecationWarning, + stacklevel=2, + ) + +import torch.distributed._shard.sharding_spec as _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d20967fbf452710405e9b6d3d48dd3d4519e0c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee29ea452143fce950421ade8803bf53776d397 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__init__.py @@ -0,0 +1,2121 @@ +from __future__ import annotations + +import math +import os +import socket +import uuid +from collections.abc import Callable, Generator +from contextlib import contextmanager +from datetime import timedelta +from enum import Enum +from functools import partial +from typing import Any, Literal + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch._C._autograd import DeviceType +from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work + + +_group_name_to_store: dict[str, c10d.Store] = {} + + +def enable_symm_mem_for_group(group_name: c10d.GroupName) -> None: + """ + Enables symmetric memory for a process group. + + Args: + group_name (str): the name of the process group. + """ + if group_name in _group_name_to_store: + return + + group = c10d._resolve_process_group(group_name) + global_ranks = sorted(c10d._world.pg_group_ranks[group].keys()) + # Different subgroups with the same name should use different stores + global_ranks_str = "_".join(map(str, global_ranks)) + store = c10d.PrefixStore( + f"symmetric_memory-{global_ranks_str}", + c10d._get_process_group_store(group), + ) + _group_name_to_store[group_name] = store + _SymmetricMemory.set_group_info( + group_name, + group.rank(), + group.size(), + store, + ) + + +_is_test_mode: bool = False +_mocked_group_names: set[str] | None = None + + +@contextmanager +def _test_mode(group_names: set[str] | None = None) -> Generator[None, None, None]: + """ + Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops + defined in the ``symm_mem`` namespace to use fallback implementations. + + The context manager is not thread safe. + """ + global _is_test_mode + global _mocked_group_names + prev = _is_test_mode + prev_group_names = _mocked_group_names + try: + _is_test_mode = True + _mocked_group_names = group_names + yield + finally: + _is_test_mode = prev + _mocked_group_names = prev_group_names + + +def is_symm_mem_enabled_for_group(group_name: c10d.GroupName) -> bool: + """ + Check if symmetric memory is enabled for a process group. + + Args: + group_name (str): the name of the process group. + """ + if _is_test_mode: + return _mocked_group_names is None or group_name in _mocked_group_names + return group_name in _group_name_to_store + + +_group_name_to_workspace_tensor: dict[str, torch.Tensor | None] = {} + + +def get_symm_mem_workspace( + group_name: c10d.GroupName, min_size: int +) -> _SymmetricMemory: + """ + Get the symmetric memory workspace associated with the process group. If + ``min_size`` is greater than the workspace associated with ``group_name``, + the workspace will be re-allocated and re-rendezvous'd. + + Args: + group_name (str): the name of the process group. + min_size (int): the size requirement for the workspace in bytes. + + Returns: + _SymmetricMemory: the symmetric memory workspace associated with the + group. + """ + enable_symm_mem_for_group(group_name) + + tensor = _group_name_to_workspace_tensor.get(group_name) + size = tensor.numel() * tensor.element_size() if tensor is not None else 0 + if tensor is None or size < min_size: + if torch.cuda.is_current_stream_capturing(): + curr_size = 0 if tensor is None else tensor.numel() * tensor.element_size() + raise RuntimeError( + f"get_symm_mem_workspace(): the requested size ({min_size} bytes) " + "is greater than the size of the currently allocated workspace " + f"({curr_size} bytes). It's currently not possible to expand the " + "workspace size during graph capture. Please invoke " + f'`get_symm_mem_workspace(group_name="{group_name}", ' + f'min_size="{min_size}")` before initiating the graph capture ' + "and try again." + ) + tensor = _SymmetricMemory.empty_strided_p2p( + (max(size, min_size),), + [1], + torch.uint8, + torch.device(f"cuda:{torch.cuda.current_device()}"), + group_name, + ) + _group_name_to_workspace_tensor[group_name] = tensor + return _SymmetricMemory.rendezvous(tensor) + + +_backend_streams: dict[int, torch.cuda.Stream] = {} + + +def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: + if priority not in _backend_streams: + _backend_streams[priority] = torch.cuda.Stream(priority=priority) + return _backend_streams[priority] + + +def _pipelined_multi_all_gather_and_consume( + shard: list[torch.Tensor], + shard_consumer: Callable[[list[torch.Tensor], int], None], + ag_out: list[torch.Tensor], + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. + # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_bufs = get_p2p_bufs(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: list[torch.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ag_out_needed, + ) + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group_name: c10d.GroupName, + out_chunk_dim: int = 0, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group_size): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + out_chunks = output.chunk( + c10d._get_group_size_by_name(group_name), dim=out_chunk_dim + ) + p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return symm_mem.get_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + for step in range(1, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend_stream + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with stream: + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. + # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + torch.cuda._sleep(100) + chunk_producer((rank + step) % group_size, p2p_buf) + symm_mem.barrier(channel=step % 2) + out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) + + # If the sleep wasn't issued in the above loop, do it now. + if group_size == 2: + torch.cuda._sleep(100) + + chunk_producer(rank, out_chunks[rank]) + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(" + "Tensor A, Tensor[] Bs, int gather_dim, str group_name, *, bool return_A = True) -> (Tensor?, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_all_gather_scaled_matmul(" + "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " + "int gather_dim, str group_name, " + "Tensor?[] biases, " + "Tensor?[] result_scales, " + "ScalarType?[] out_dtypes, " + "bool[] use_fast_accum) -> (Tensor, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_scaled_matmul_reduce_scatter(" + "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, SymInt[]? output_shape, " + "Tensor? bias = None, " + "Tensor? result_scale = None, " + "ScalarType? out_dtype = None, " + "bool use_fast_accum = False) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") +lib.define( + "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" +) + +lib.define("get_remote_tensors(Tensor x, str group_name) -> Tensor[]") +""" +Given a local tensor and a group name, return a tuple of tensors that are +symmetric on other devices. The returned tensors are ordered by rank IDs. The +length of the tuple equals to the size of the group. + +Note: this API works only when `world_within_direct_access()` returns True, i.e. +only when the group is within NVLink domain or similar. It does not work across +network interfaces. +""" + + +@torch.library.impl(lib, "get_remote_tensors", "CUDA") +def _get_remote_tensors_default( + local: torch.Tensor, group_name: c10d.GroupName +) -> tuple[torch.Tensor, ...]: + hdl = rendezvous(local, group_name) + if hdl is None: + raise ValueError("Tensor is not allocated from Symmetric Memory") + + return tuple( + hdl.get_remote_tensor(peer, local.size(), local.dtype) + for peer in range(hdl.world_size) + ) + + +@torch.library.impl(lib, "get_remote_tensors", "Meta") +def _get_remote_tensors_meta( + local: torch.Tensor, group_name: c10d.GroupName +) -> tuple[torch.Tensor, ...]: + group = c10d._resolve_process_group(group_name) + return tuple(torch.empty_like(local) for _ in range(group.size())) + + +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: torch.Tensor, scale: torch.Tensor | None, gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + +def _fused_all_gather_matmul_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_types) must be the same as len(Bs)") + if len(kwargs_list) != len(Bs): + raise ValueError("len(kwargs_list) must be the same as len(Bs)") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + + if gather_dim == A_shard.ndim - 1 or gather_dim == -1: + return _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op, + A_shard, + Bs, + A_scale, + kwargs_list, + out_dtypes, + gather_dim, + group_name, + return_A, + ) + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) + for B, out_dtype in zip(Bs, out_dtypes) + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() + ) + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + return_A, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + return_A, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + return_A, + ) + + A = unflatten(A_flat) if return_A else None + return A, [unflatten(output) for output in outputs] + + +def _pipelined_all_gather_and_consume_last_dim( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: c10d.GroupName, + ag_out_needed: bool = True, +) -> None: + p2p_workspace_size_req = 0 + p2p_workspace_size_req = shard.numel() * shard.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None: + dst.copy_(src) + + def get_p2p_buf(remote_rank: int) -> torch.Tensor: + buf = symm_mem.get_buffer( + remote_rank, + shard.shape, + shard.dtype, + ) + return buf + + local_p2p_buf = get_p2p_buf(rank) + + shards = ag_out.chunk(group_size) + + copy_shard(dst=local_p2p_buf, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_buf = get_p2p_buf(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_buf) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + group = c10d._resolve_process_group(group_name) + group_size = group.size() + + B_shards = [B.chunk(group.size()) for B in Bs] + + leading_dims = list(A_shard.shape[:-1]) + A_shard_flat = A_shard.flatten(0, -2) + + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1) + + A_flat_out = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + torch.empty( + (A_shard_flat.shape[0], B.shape[1]), + dtype=out_dtype or B.dtype, + device=A_shard.device, + ) + for B, out_dtype in zip(Bs, out_dtypes) + ] + + first = True + events = [torch.cuda.Event() for _ in outputs] + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + nonlocal first + for out, event, B_shard, kwargs in zip(outputs, events, B_shards, kwargs_list): + event.wait() + if first: + torch.ops.aten.mm.out(shard, B_shard[rank], **kwargs, out=out) + else: + out.addmm_(shard, B_shard[rank]) + event.record() + + first = False + + _pipelined_all_gather_and_consume_last_dim( + A_shard_flat, + default_consumer, + A_flat_out, + group_name, + return_A, + ) + ret_A = None + if return_A: + # This path is inefficient and will be filtered out at passes stage + # Added only for completeness. + A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1) + ret_A = unflatten(A_split_cat_out_flat) + + return ret_A, [unflatten(output) for output in outputs] + + +@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + if gather_dim == A.ndim - 1 or gather_dim == -1: + A_splits = A.chunk(group_size) + A_mm = torch.cat(A_splits, dim=-1) + res = [torch.matmul(A_mm, B) for B in Bs] + if return_A: + return A_mm, res + else: + return None, res + + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs] + if return_A: + return A.movedim(0, gather_dim), res + else: + return None, res + + +@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + + Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + if _is_test_mode: + return _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim, group_name, return_A=return_A + ) + + if _should_use_fused_all_gather_matmul_native(A_shard, Bs, gather_dim, group_name): + group = c10d._resolve_process_group(group_name) + leading_dims = list(A_shard.shape[:-1]) + leading_dims[0] *= group.size() + A, out = _fused_all_gather_matmul_native( + A_shard.flatten(0, -2), Bs[0], group_name + ) + return A.view(*leading_dims, -1), [out.view(*leading_dims, -1)] + + if _should_use_multimem_all_gather_matmul( + A_shard, gather_dim, group_name, return_A + ): + return None, _multimem_all_gather_matmul(A_shard, Bs, group_name) + + with torch.profiler.record_function("fused_all_gather_matmul"): + return _fused_all_gather_matmul_impl( + torch.ops.aten.mm.out, + A_shard, + Bs, + None, + [{} for B in Bs], + [B.dtype for B in Bs], + gather_dim, + group_name, + return_A, + ) + + +def _should_use_fused_all_gather_matmul_native( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + + return ( + "TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP" in os.environ + and A_shard.is_contiguous() + and gather_dim == 0 + # _async_input_mm requires local_M to be divisible by world_size. + and local_M % group.size() == 0 + # _async_input_mm outperforms the decomposition-based approach when the + # global M is small. + and 2048 < local_M * group.size() <= 4096 + # _async_input_mm only supports a single B. + and len(Bs) == 1 + ) + + +def _fused_all_gather_matmul_native( + A_shard: torch.Tensor, + B: torch.Tensor, + group_name: c10d.GroupName, +) -> tuple[torch.Tensor, torch.Tensor]: + symm_mem = rendezvous(A_shard, group_name) + if symm_mem is None: + symm_mem = get_symm_mem_workspace( + group_name, A_shard.numel() * A_shard.element_size() + ) + symm_mem.barrier() + buf = symm_mem.get_buffer(symm_mem.rank, A_shard.shape, A_shard.dtype) + buf.copy_(A_shard) + A_shard = buf + + rank = symm_mem.rank + world_size = symm_mem.world_size + + current_stream = torch.cuda.current_stream() + backend_stream = _get_backend_stream(priority=-1) + + symm_mem.barrier() + backend_stream.wait_stream(current_stream) + current_stream.wait_stream(backend_stream) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_signals = torch.zeros(world_size, dtype=torch.uint32, device=A_shard.device) + A_shards = A.chunk(world_size) + + A_shards[rank].copy_(A_shard) + if not torch.cuda.is_current_stream_capturing(): + _SymmetricMemory.stream_write_value32(A_signals, rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=rank, val=1, count=1) + + out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank) + for step in range(1, world_size): + src_rank = (rank + step) % world_size + src_buf = symm_mem.get_buffer(src_rank, A_shard.shape, A_shard.dtype) + with backend_stream: + A_shards[src_rank].copy_(src_buf) + if not torch.cuda.is_current_stream_capturing(): + # cuStreamWriteValue32 issues a system level fence before the write + _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=src_rank, val=1, count=1) + + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + symm_mem.barrier() + return A, out + + +def _should_use_multimem_all_gather_matmul( + A_shard: torch.Tensor, + gather_dim: int, + group_name: c10d.GroupName, + return_A: bool, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + has_multicast_support = ( + A_shard.device.type == "cuda" + and _SymmetricMemory.has_multicast_support( + DeviceType.CUDA, A_shard.device.index + ) + ) + + return ( + has_multicast_support + and not return_A + and A_shard.is_contiguous() + and gather_dim == 0 + # The heuristic is empirical. We could refine it with a more + # sophisticated perf model. + and local_M * group.size() <= 2048 + ) + + +def _multimem_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + group_name: c10d.GroupName, +) -> list[torch.Tensor]: + group = c10d._resolve_process_group(group_name) + A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:])) + symm_mem = get_symm_mem_workspace( + group_name, A_shape.numel() * A_shard.element_size() + ) + A = symm_mem.get_buffer(symm_mem.rank, A_shape, A_shard.dtype) + torch.ops.symm_mem.multimem_all_gather_out(A_shard, group_name, A) + return [torch.matmul(A, B) for B in Bs] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") +def _fused_all_gather_scaled_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = torch.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = torch.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + + def scaled_matmul( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + bias: torch.Tensor | None, + result_scale: torch.Tensor | None, + out_dtype: torch.dtype | None, + use_fast_accum: bool, + ) -> torch.Tensor: + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm( + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return res.unflatten(0, leading_dims) + + return A.movedim(0, gather_dim), [ + scaled_matmul( + A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum + ).movedim(0, gather_dim) + for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( + Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") +def _fused_all_gather_scaled_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: c10d.GroupName, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + A = all_gather_tensor(A_shard, gather_dim, group_name) + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) + res = res.unflatten(0, leading_dims) + + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + if len(biases) != len(Bs): + raise ValueError("len(biases) must be the same as len(Bs)") + if len(result_scales) != len(Bs): + raise ValueError("len(result_scales) must be the same as len(Bs)") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_dtypes) must be the same as len(Bs)") + if len(use_fast_accum) != len(Bs): + raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") + + if _is_test_mode: + return _fused_all_gather_scaled_matmul_fallback( + A_shard, + Bs, + A_scale, + B_scales, + gather_dim, + group_name, + biases, + result_scales, + out_dtypes, + use_fast_accum, + ) + + with torch.profiler.record_function("fused_all_gather_scaled_matmul"): + A, res = _fused_all_gather_matmul_impl( + torch.ops.aten._scaled_mm.out, + A_shard, + Bs, + A_scale, + [ + { + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": fast_accum, + } + for B_scale, bias, result_scale, out_dtype, fast_accum in zip( + B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ], + out_dtypes, + gather_dim, + group_name, + True, + ) + assert A is not None + return A, res + + +def make_contiguous_for_perm( + t: torch.Tensor, + perm: list[int], +) -> torch.Tensor: + """ + Restride `t` such that `t.permute(perm)` is contiguous. + """ + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return t.permute(perm).contiguous().permute(inv_perm) + + +def restride_A_shard_for_fused_all_gather_matmul( + t: torch.Tensor, + gather_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. + See the doc for `fused_all_gather_matmul` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(gather_dim)) + return make_contiguous_for_perm(t, perm) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no + extra copy is required for input layout transformation. Otherwise A needs + to be copied once. + """ + if _is_test_mode: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + with torch.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten.mm.out, + A=A, + B=B, + kwargs={}, + out_dtype=A.dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +def _fused_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + scatter_dim: int, + group_name: c10d.GroupName, +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + + if scatter_dim == A.ndim - 1: + B_shards = B.chunk(group.size(), dim=B.ndim - 1) + A_flat = A.flatten(0, -2) + + def _chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_flat, B_shards[rank], **kwargs, out=out) + + leading_dims = list(A.shape[:-1]) + + stacked_partials = torch.empty( + (A_flat.shape[0], B.shape[1]), + dtype=out_dtype or A.dtype, + device=A.device, + ) + + _pipelined_produce_and_all2all( + _chunk_producer, + stacked_partials, + group_name, + out_chunk_dim=1, + ) + + stacked_partials_view = stacked_partials.reshape( + *leading_dims, group.size(), -1 + ) + return reduce_fn( + stacked_partials_view, + dim=-2, + ) + + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + A_shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, **kwargs, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") +def _fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if _is_test_mode: + return _fused_scaled_matmul_reduce_scatter_fallback( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + with torch.profiler.record_function("fused_scaled_matmul_reduce_scatter"): + return _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten._scaled_mm.out, + A=A, + B=B, + A_scale=A_scale, + kwargs={ + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": use_fast_accum, + }, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") +def _fused_scaled_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: c10d.GroupName, + output_shape: list[int], +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if ( + scatter_dim_after_maybe_reshape < 0 + or scatter_dim_after_maybe_reshape >= A.dim() + ): + raise ValueError("Invalid scatter dim for 2D tensor input to scaled_mm") + if orig_scatter_dim < 0 or orig_scatter_dim >= len(output_shape): + raise ValueError("Invalid scatter dim for 3D+ output tensor") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + + # Move scatter to first dim, then shard the tensor along the first dim, so the chunk producer + # can perform matmuls along the first dim. + A_with_scatter_dim_0 = A.movedim(scatter_dim_after_maybe_reshape, 0) + + # To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs. + A_2D_with_scatter_dim_0 = A_with_scatter_dim_0.flatten(0, -2) + + # Partition A along the first dim to prepare for sharding across TP process group. + A_shards = A_2D_with_scatter_dim_0.chunk(group.size()) + + # Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly. + # How we do this depends on if we are using tensorwise scaling, rowwise scaling, or no scaling. + tensorwise_scaling = A_scale is not None and A_scale.numel() == 1 + rowwise_scaling = A_scale is not None and A_scale.numel() > 1 + + # For tensorwise scaling, the scale should be replicated so each shard has a copy. + if tensorwise_scaling: + A_scale_shards = [A_scale] * group.size() + + # For rowwise scaling, we need to move the scatter dim to the first dim to match the + # dim swap of the 'A' tensor. Then we can shard the scales along the first dim, just like + # the 'A' tensor. + elif rowwise_scaling: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = ( + A_scale.movedim(scatter_dim_after_maybe_reshape, 0) + .contiguous() + .flatten(0, -2) + ) + A_scale_shards = list(A_scale.chunk(group.size())) + # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. + # When we slice them we might break this and need to reallocate them. + A_scale_shards = [ + t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards + ] + else: + raise ValueError("A_scale cannot be none for scaled_mm") + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out) + + # Stacked partials will be the 2D outputs of the pipelined scaled mm, and will + # have the shape (A_with_scatter_dim_0_tensor.shape[0], B.shape[1]) to align with the formula: + # (a*b,c) @ (c,d) = (a*b,d) + stacked_partials = A_with_scatter_dim_0.new_empty( + A_2D_with_scatter_dim_0.shape[0], B.shape[1], dtype=out_dtype or A.dtype + ) + + # Execute the pipelined mm/scaled_mm. + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # We now need to transform the *unreduced* stacked 2D partial mm outputs to an *unreduced* 3D+ output, + # then reduce-scatter. To do this, we first need to determine the shape of the unreduced 3D+ output, + # to reshape our stacked partials so we can apply the reduce-scatter. + # + # The *unreduced* 3D+ tensor will have dim 0 = `group_size`, as we have `group_size` instances of + # stacked partial outputs. The next dims will be A's leading dims (sharded along the original scatter dim), + # as it was the left operand of the mm op. We can use -1 as the final dim of the view to populate the rest. + stacked_partials_3D_leading_dims = [group.size()] + list( + # We use A from after the dim swap 0<=>scatter_dim, but before the flatten, + # to get the leading dims of the 3D+ view of stacked partials. + A_with_scatter_dim_0.shape[:-1] + ) + + # The `group_size` leading dim has been prepended to `stacked_partials_3D_leading_dims`, + # to capture the partial output from each rank. We need to divide the sharding/scatter dim + # by the group size. If the original scatter dim was 0, then it is now dim 1 in this + # tensor, since this new `group_size` dim was prepended. + stacked_partial_scatter_dim = orig_scatter_dim if orig_scatter_dim > 0 else 1 + stacked_partials_3D_leading_dims[stacked_partial_scatter_dim] //= group.size() + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + reduced_out = reduce_fn( + # View 2D stacked partials as 3D+ tensor of shape (`group_size`, ...) + stacked_partials.view(*stacked_partials_3D_leading_dims, -1) + # We originally swapped 0<=>scatter_dim_after_maybe_reshape. Now after + # prepending the `group_size` dim, to undo this original swap, we + # must swap 1<=>scatter_dim_after_maybe_reshape+1. + .movedim(1, scatter_dim_after_maybe_reshape + 1), + # Reduce along the `group_size` dim (0). + dim=0, + ) + + # Output shape must be scattered along original scatter dim as well. + output_shape[orig_scatter_dim] //= group.size() + out = reduced_out.view(*output_shape) + return out + + +def restride_A_for_fused_matmul_reduce_scatter( + t: torch.Tensor, + scatter_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal + perf. See the doc for `fused_matmul_reduce_scatter` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(scatter_dim)) + return make_contiguous_for_perm(t, perm) + + +def _maybe_convert_scalar_types_to_dtypes( + scalar_types: list[Any], +) -> list[torch.dtype | None]: + """ + When a list of `torch.dtype`s is passed through the dispatcher as + `ScalarType[]`, it is converted to a list of scalar type enum values. This + function converts it back to a list of `torch.dtype`s. + """ + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + _SCALAR_TYPE_TO_DTYPE = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.short, + 3: torch.int, + 4: torch.int64, + 5: torch.half, + 6: torch.float, + 7: torch.double, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex128, + 11: torch.bool, + 12: torch.qint8, + 13: torch.quint8, + 14: torch.qint32, + 15: torch.bfloat16, + 16: torch.float8_e5m2, + 17: torch.float8_e4m3fn, + 18: torch.float8_e5m2fnuz, + 19: torch.float8_e4m3fnuz, + } + if any(not isinstance(x, (type(None), int)) for x in scalar_types): + return scalar_types + + dtypes: list[torch.dtype | None] = [] + for scalar_type in scalar_types: + if scalar_type is None: + dtypes.append(scalar_type) + elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: + raise ValueError(f"Unrecognized scalar type {scalar_type}") + else: + dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) + return dtypes + + +class Work(_Work): + def __init__(self) -> None: + super().__init__() + self.event = torch.cuda.Event() + self.event.record() + + def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: + self.event.wait() + return True + + +""" +NOTE [low-contention collectives] +When a collective is overlapped with abundant compute, it makes sense to +prioritize reducing the contention between the collective and the overlapped +compute, even at the cost of a slightly slower collective. + +Common collective implementations (e.g., NCCL without user buffer +registration) optimize for throughput with no ambient compute. However, such +implementations may not be optimal when they are overlapped with compute: +- These implementations typically fuse the entire collective into a single +kernel and reserve SM resources based on the most demanding portion of the +collective, even when a large portion of the collective does not require this +much resource. +- These implementations often use SM-based P2P copy as opposed to copy +engine-based P2P copy. Copy engine-based P2P copy may not have a significant +advantage when there's no ambient compute. However, it may significantly +improve overall resource utilization in the presence of ambient compute. + +When overlapped with intensive compute (e.g., persistent matmul kernels), the +SM-usage of a collective can lead to inefficient overlapping. + +Low-contention collectives achieve their goals with the following strategies: +- Use copy engine-based copy whenever possible. +- Break down portions of a collective with different resource requirements +into multiple kernels. This improves the overlapping efficiency at the cost +of additional launching overhead. +""" + + +@torch.library.impl(lib, "_low_contention_all_gather", "Meta") +def _low_contention_all_gather_meta( + tensor: torch.Tensor, + group_name: c10d.GroupName, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) + + +@torch.library.impl(lib, "_low_contention_all_gather", "CUDA") +def _low_contention_all_gather( + tensor: torch.Tensor, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Performs all-gather with symmetric memory in a low-contention fashion. + + When `tensor` is already in symmetric memory: + - The collective is carried out without using SMs. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - An extra SM-based copy is performed to copy the input data into the + symmetric memory workspace. + - Symmetric memory workspace size requirement: the size of `tensor`. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + input_is_symm_mem = True + else: + symm_mem = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + input_is_symm_mem = False + + rank = symm_mem.rank + world_size = symm_mem.world_size + + output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) + chunks = output.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + if not input_is_symm_mem: + local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) + local_buf.copy_(tensor) + # pull + symm_mem.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + torch._C._distributed_c10d._register_work(output, Work()) + return output + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta") +def _low_contention_reduce_scatter_meta( + tensor: torch.Tensor, + reduce_op: str, + group_name: c10d.GroupName, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.unflatten(0, (group_size, -1)).mean(dim=0) + + +def _low_contention_reduce_scatter_with_symm_mem_input( + tensor: torch.Tensor, + reduce_op: str, + symm_mem: _SymmetricMemory, +) -> torch.Tensor: + rank = symm_mem.rank + world_size = symm_mem.world_size + + assert tensor.shape[0] % world_size == 0 + a2a_res = torch.empty_like(tensor) + chunks = a2a_res.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # pull + offline reduction + symm_mem.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer( + remote_rank, + chunks[0].shape, + chunks[0].dtype, + chunks[0].numel() * rank, + ) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + + ret = a2a_res.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +def _low_contention_reduce_scatter_with_workspace( + tensor: torch.Tensor, + reduce_op: str, + workspace: _SymmetricMemory, +) -> torch.Tensor: + rank = workspace.rank + world_size = workspace.world_size + + assert tensor.shape[0] % world_size == 0 + chunks = tensor.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # push + offline reduction + workspace.barrier() + for step in range(world_size): + remote_rank = (rank - step) % world_size + dst_buf = workspace.get_buffer( + remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank + ) + dst_buf.copy_(chunks[remote_rank]) + workspace.barrier() + + buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) + ret = buf.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") +def _low_contention_reduce_scatter( + tensor: torch.Tensor, + reduce_op: str, + group_name: c10d.GroupName, +) -> torch.Tensor: + """ + Performs reduce-scatter with symmetric memory in a low-contention fashion. + + This implementation performs a P2P-based all-to-all followed by an offline + reduction. + + When `tensor` is already in symmetric memory: + - Pull-based all-to-all is used. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - Push-based all-to-all is used. + - Symmetric memory workspace size requirement: the size of `tensor`. + + SM-usage: + - SM-based copy of the rank's own chunk for the all-to-all. + - Reduction on the all-to-all result. + + TODO(yifu): the SM-based copy can be avoided with a list-based reduction + kernel. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + return _low_contention_reduce_scatter_with_symm_mem_input( + tensor, reduce_op, symm_mem + ) + else: + workspace = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + return _low_contention_reduce_scatter_with_workspace( + tensor, reduce_op, workspace + ) + + +@torch.library.impl(lib, "all_to_all_vdev_2d", "Meta") +def _all_to_all_vdev_2d_meta( + input: torch.Tensor, + out: torch.Tensor, + in_splits: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: c10d.GroupName, + major_align: int | None = None, +) -> None: + return None + + +@torch.library.impl(lib, "all_to_all_vdev_2d_offset", "Meta") +def _all_to_all_vdev_2d_offset_meta( + input: torch.Tensor, + out: torch.Tensor, + in_splits_offsets: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: c10d.GroupName, +) -> None: + return None + + +# ============================================================================= +# User-facing APIs +# ============================================================================= + + +from collections.abc import Sequence +from typing import overload, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from torch._C._distributed_c10d import ProcessGroup + from torch.types import _device, _dtype, _int + + +@overload +def empty( + *size: _int, dtype: _dtype | None = None, device: _device | None = None +) -> torch.Tensor: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def empty( + size: Sequence[_int], + *, + dtype: _dtype | None = None, + device: _device | None = None, +) -> torch.Tensor: ... + + +def empty( # type: ignore[misc] + *size: Any, + dtype: _dtype | None = None, + device: _device | None = None, +) -> torch.Tensor: + r""" + Similar to :func:`torch.empty()`. The returned tensor can be used by + :func:`torch._distributed._symmetric_memory.rendezvous()` to establish a + symmetric memory tensor among participating processes. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + """ + if len(size) == 1 and isinstance(size[0], Sequence): + size = tuple(size[0]) + else: + size = tuple(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + if device is None: + device = torch.get_default_device() + + return _SymmetricMemory.empty_strided_p2p( + size=size, + stride=torch._prims_common.make_contiguous_strides_for(size), + dtype=dtype, + device=torch.device(device), + ) + + +def rendezvous( + tensor: torch.Tensor, group: Union[c10d.GroupName, ProcessGroup] +) -> _SymmetricMemory: + r""" + rendezvous(tensor, group) -> _SymmetricMemory + + Establish a symmetric memory tensor among participating processes. This is + a collective operation. + + Args: + tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor. + It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape, + dtype, and device type must be identical across all participating processes. + group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the + participating processes. This can be either a group name or a process group object. + """ + from torch._C._distributed_c10d import ProcessGroup + + if isinstance(group, str): + group_name = c10d.GroupName(group) + elif isinstance(group, ProcessGroup): + group_name = group.group_name + else: + raise TypeError(f"rendezvous: unsupported group type: {type(group)}") + + enable_symm_mem_for_group(group_name) + return _SymmetricMemory.rendezvous(tensor, group_name) + + +def is_nvshmem_available() -> bool: + r""" + is_nvshmem_available() -> bool + + Check if NVSHMEM is available in current build and on current system. + """ + try: + from torch._C._distributed_c10d import _is_nvshmem_available + except ImportError: + # Not all builds have NVSHMEM support. + return False + + # Check if NVSHMEM is available on current system. + return _is_nvshmem_available() + + +def set_backend(name: Literal["NVSHMEM", "CUDA", "NCCL"]) -> None: + r""" + Set the backend for symmetric memory allocation. This is a global setting + and affects all subsequent calls to + :func:`torch._distributed._symmetric_memory.empty()`. Note that the backend + cannot be changed once a symmetric memory tensor has been allocated. + + Args: + backend (str): the backend for symmetric memory allocation. Currently, + only `"NVSHMEM"`, `"CUDA"`, `"NCCL"` are supported. + """ + _SymmetricMemory.set_backend(name) + + +def get_backend(device: _device) -> str | None: + r""" + Get the backend for symmetric memory allocation for a given device. If not + found, return None. + + Args: + device (`torch.device` or str): the device for which to get the backend. + """ + return _SymmetricMemory.get_backend(torch.device(device)) + + +def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def] + r""" + Get the MemPool allocator for symmetric memory for a given device. + + Args: + device (`torch.device` or str): the device for which to get the MemPool + allocator. + """ + return _SymmetricMemory.get_mempool_allocator(torch.device(device)) + + +def set_signal_pad_size(size: int) -> None: + r""" + Set the signal pad size for future symmetric memory allocations. + + Signal pads are P2P-accessible memory regions used for synchronization in + symmetric memory. This function allows users to configure + the signal pad size to be proportional to their workload requirements. + + .. warning:: + This must be called before any symmetric memory allocations are made. + The size cannot be changed after allocations have been performed. + + Args: + size (int): the signal pad size in bytes. The size should be + proportional to the number of blocks launched and the world size. + + Example:: + + >>> # doctest: +SKIP + >>> # Set a larger signal pad size before any allocations + >>> torch.distributed._symmetric_memory.set_signal_pad_size(1024 * 1024) # 1MB + """ + _SymmetricMemory.signal_pad_size = size + + +def get_signal_pad_size() -> int: + r""" + Get the current signal pad size for symmetric memory allocations. + + Returns the user-configured size if set via :func:`set_signal_pad_size`, + otherwise returns the default size. + + Returns: + int: the signal pad size in bytes. + + Example:: + + >>> # doctest: +SKIP + >>> size = torch.distributed._symmetric_memory.get_signal_pad_size() + >>> print(f"Signal pad size: {size} bytes") + """ + return _SymmetricMemory.signal_pad_size + + +# An internal map from device to the symmetric memory pool for that device. +_symm_mem_pools: dict[_device, torch.cuda.MemPool] = {} + + +def get_mem_pool(device: _device) -> torch.cuda.MemPool: + """ + Get the symmetric memory pool for a given device. If not found, create a new + pool. + + The tensor allocations with this pool must be symmetric across ranks. The + allocated tensors can be used with symmetric operations, for example, + operations defined under `torch.ops.symm_mem`. + + Args: + device (`torch.device` or str): the device for which to get the symmetric memory pool. + + Returns: + `torch.cuda.MemPool`: the symmetric memory pool for the given device. + + Example:: + + >>> # doctest: +SKIP + >>> pool = torch.distributed._symmetric_memory.get_mem_pool("cuda:0") + >>> with torch.cuda.use_mem_pool(pool): + >>> tensor = torch.randn(1000, device="cuda:0") + >>> tensor = torch.ops.symm_mem.one_shot_all_reduce(tensor, "sum", group_name) + + """ + # This function is a wrapper around the `torch.cuda.MemPool` constructor. + # Due to special requirements of SymmetricMemory, we preset certain options for the pool. + # - use_on_oom=False: we don't want to lend the space of the pool for + # non-symmetric allocations because this could desync the allocation state + # across ranks. + # - no_split=True: we don't want to split segments, because today a segment + # is associated with a signal pad, if two allocated tensors share a segment + # and their kernels concurrently use (the same) signal pad, this could cause + # undefined behaviors. We could consider relaxing this in the future if we + # establish stream tracking and implicit synchronization around an + # allocation. + if device not in _symm_mem_pools: + allocator = get_mempool_allocator(device) + # Create a new pool with the given allocator and the preset options. + _symm_mem_pools[device] = torch.cuda.MemPool( + allocator, + use_on_oom=False, + no_split=True, + ) + + return _symm_mem_pools[device] + + +__all__ = [ + "empty", + "rendezvous", + "is_nvshmem_available", + "set_backend", + "get_backend", + "set_signal_pad_size", + "get_signal_pad_size", + "get_mem_pool", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba24a1ac17814c24e15db9f3ede69d15f1268c8c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30a17696b52aae443a77c22fb01e59b569c69f36 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca8bc95eae39ebefdd97f8805b97099a82e9a92 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -0,0 +1,1220 @@ +import logging +import os +import subprocess +import sysconfig +from typing import Any + +import torch.distributed as dist +from torch.utils._triton import has_triton + + +logger = logging.getLogger(__name__) + + +class NvshmemLibFinder: + """ + A class to find path to the NVSHMEM device library. + + Environment variable: + + `NVSHMEM_LIB_DIR` (Optional[str]): The directory where the NVSHMEM device + library is located. If not provided, it will use the default path where + NVSHMEM wheel is installed, or search for the library in common system + paths. + """ + + # Class variable to store the found library path for reuse + found_device_lib_path: str | None = None + + @classmethod + def find_device_library(cls) -> str: + """ + Find the path to the NVSHMEM device library. + + Returns: + str: The path to libnvshmem_device.bc (included). + """ + if cls.found_device_lib_path is not None: + # Return the cached path if it exists + return cls.found_device_lib_path + + # First, check if the user has specified a custom library path + user_lib_dir = os.environ.get("NVSHMEM_LIB_DIR", None) + if user_lib_dir is not None: + lib_path = os.path.join(user_lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError( + f"NVSHMEM device library not found at specified path: {user_lib_dir}" + ) + cls.found_device_lib_path = lib_path + return lib_path + + # Otherwise, search for the library in the default installation paths + paths = [ + os.path.join(sysconfig.get_path("purelib"), "nvidia", "nvshmem", "lib") + ] + + # Add common system installation paths + common_paths = [ + "/usr/local/lib", + "/usr/lib", + "/opt/nvidia/nvshmem/lib", + ] + paths.extend(common_paths) + + try: + import torch + + torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib") + so_path = os.path.join(torch_lib, "libtorch_nvshmem.so") + + if os.path.exists(so_path): + try: + result = subprocess.run( + ["readelf", "-d", so_path], + capture_output=True, + text=True, + check=True, + ) + + for line in result.stdout.splitlines(): + if ("RPATH" in line or "RUNPATH" in line) and "[" in line: + rpath = line.split("[", 1)[1].split("]", 1)[0] + for p in rpath.split(":"): + p = p.strip().replace("$ORIGIN", torch_lib) + if p and p not in paths: + paths.append(p) + except subprocess.CalledProcessError: + pass + + except ImportError: + pass + + for path in paths: + device_lib = os.path.join(path, "libnvshmem_device.bc") + if os.path.exists(device_lib): + cls.found_device_lib_path = device_lib + return device_lib + + raise RuntimeError(f"NVSHMEM device library not found. Searched: {paths}") + + +def enable_triton(lib_dir: str | None = None) -> dict[str, str]: + raise NotImplementedError( + "`enable_triton` is deprecated. " + "If you need NVSHMEM device function support for Triton, " + "please use `@requires_nvshmem` to decorate your Triton kernel. ", + ) + + +class NvshmemKernelRegistry: + """ + A class to register kernel functions that ** require NVSHMEM initialization ** + """ + + # Class variable to store the functions to be initialized + _to_init: dict[str, Any] = {} + + @classmethod + def register(cls, name: str) -> None: + """ + Register a kernel function with the given name. + + Args: + name (str): The name of the kernel function. + """ + cls._to_init.setdefault(name) + + @classmethod + def deregister(cls, name: str) -> None: + """ + Deregister a kernel function with the given name. + + Args: + name (str): The name of the kernel function. + """ + cls._to_init.pop(name, None) + + @classmethod + def has(cls, name: str) -> bool: + """ + Check if a kernel function with the given name is registered. + + Args: + name (str): The name of the kernel function. + + Returns: + bool: True if the kernel function is registered, False otherwise. + """ + return name in cls._to_init + + +def _nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] + """ + A hook function to initialize the CUModule created by `triton.jit` with + NVSHMEM device context + """ + from torch._C._distributed_c10d import _nvshmemx_cumodule_init + + jit_function = kwargs["fn"].jit_function + fn_name = jit_function.fn.__name__ + + # Only initialize NVSHMEM module for kernels registered via @requires_nvshmem + if NvshmemKernelRegistry.has(fn_name): + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache = jit_function.device_caches[device][0] + kernel = kernel_cache.get(key, None) + if kernel is not None: + kernel.run + # Initialize NVSHMEM for the CU module + _nvshmemx_cumodule_init(kernel.module) + else: + logger.warning( + f"It seems Triton hasn't created a kernel for function {fn_name}. " # noqa: G004 + "Please report this issue to Triton." + ) + + +if has_triton(): + from triton.runtime.jit import JITFunction, KernelInterface + + # Create a new Callable class that follows the KernelInterface protocol so + # that the Callable works with the subscript operator, e.g. `foo[(1, 1)]` + class GridCallableWithExtern(KernelInterface): + """ + `KernelInterface` invokes `self.run` in `__getitem__`, i.e. []. We + implement a `run` method by directing the call to `JITFunction.run`, + with added extern_libs kwarg, so that users don't have to pass it + """ + + def __init__(self, jit_func: JITFunction, extern_libs: dict[str, str]) -> None: + self.jit_func = jit_func + self.extern_libs = extern_libs + + def run(self, *args, **kwargs): # type: ignore[no-untyped-def] + # Call the JITFunction.run with added extern_libs kwarg + return self.jit_func.run(*args, **kwargs, extern_libs=self.extern_libs) + + +def requires_nvshmem( # type: ignore[no-untyped-def] + jit_func, # JITFunction created by triton.jit +): + """ + A decorator to register a Triton kernel function that requires NVSHMEM initialization. + + Example usage: + ``` + @requires_nvshmem + @triton.jit + def foo(...): + ... + ``` + + If you would like to specify a path to the NVSHMEM device library other + than standard search locations, you can use the following environment + variable: + ``` + export NVSHMEM_LIB_DIR=/path/to/nvshmem/lib + ``` + """ + + import triton + from triton.runtime.jit import JITFunction + + if not isinstance(jit_func, JITFunction): + raise TypeError(f"Expected a JITFunction, but got {type(jit_func)}") + + # Find the NVSHMEM device library + lib_path = NvshmemLibFinder.find_device_library() + extern_libs = {"libnvshmem_device": lib_path} + + # Register the JITFunction with the kernel registry as "to be initialized" + NvshmemKernelRegistry.register(jit_func.fn.__name__) + + # Register the NVSHMEM init function as a post-compile hook. + # [Note] This is a global setting (due to lack of Triton API exposure). To + # avoid initializing Triton kernels that do not require NVSHMEM, filtering + # is performed in the hook function itself by checking against + # NvshmemKernelRegistry. + triton.knobs.runtime.jit_post_compile_hook = _nvshmem_init_hook + + return GridCallableWithExtern(jit_func, extern_libs) + + +if has_triton(): + import triton + import triton.language as tl + from triton.language import core + + @triton.jit # type: ignore[misc] + def put(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Put tensor data from local PE to a remote PE. + + This high-level function provides a tensor-aware interface for NVSHMEM put + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the remote PE. Type must match source. + source: Source tensor on the local PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been copied out + of the source array on the local PE. + - The operation does not guarantee delivery to the destination PE. + Use nvshmem_fence() for ordering or nvshmem_quiet() for completion. + + Example: + ``` + # Transfer 100 elements to PE 1 + nvshmem.put(dest_tensor, src_tensor, 100, 1) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return putmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def putmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM put""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_putmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def get(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Get tensor data from a remote PE to local PE. + + This high-level function provides a tensor-aware interface for NVSHMEM get + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a blocking operation that returns after data has been delivered + to the destination array on the local PE. + - The destination data is guaranteed to be available for use after the call returns. + + Example: + ``` + # Get 100 elements from PE 0 + nvshmem.get(dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def getmem_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_getmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def get_nbi(dest, source, nelems, pe): # type: ignore[no-untyped-def] + """ + Get tensor data from a remote PE to local PE, non-blocking. + + Different from the `get` function, this function returns after + initiating the operation. The operation is considered complete after a + subsequent call to `quiet`. + + Args: + dest: Destination tensor on the local PE. Type must match source. + source: Source tensor on the remote PE containing data to be copied. + nelems: Number of elements to transfer. + pe: PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + + Example: + ``` + # Get 100 elements from PE 0 + nvshmem.get_nbi(dest, src, 100, 0) + # Some independent computation which overlaps with the get operation + ... + # Wait for completion of the get operation + nvshmem.quiet() + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return getmem_block_extern_wrapper( + dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe + ) + + @core.extern + def getmem_nbi_block_extern_wrapper(dest, source, size_bytes, pe, _semantic=None): # type: ignore[no-untyped-def] + """Low-level extern wrapper for NVSHMEM get""" + return core.extern_elementwise( + "", + "", + [dest, source, size_bytes, pe], + { + ( + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe number + ): ("nvshmemx_getmem_nbi_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def putmem_signal_block( # type: ignore[no-untyped-def] + dst, + src, + size_bytes, + signal, + sig_val, + sig_op, + pe, + ): # type: ignore[no-untyped-def] + """ + Put data to remote PE with atomic signal operation using block-scoped operation. + + This function copies data from the local PE to the remote PE and then + atomically updates a signal variable on the remote PE to indicate completion. + This enables efficient point-to-point synchronization between PEs. + + Args: + dst (tensor): A tensor on calling PE symmetric to the destination tensor on remote PE. + src (tensor): Local tensor containing the source data. + size_bytes (int64): Number of bytes to transfer. Must be positive. + signal (tensor): Symmetric signal pad with remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int32): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomic set operation + - NVSHMEM_SIGNAL_ADD (5): Atomic add operation + pe (int32): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that returns after data has been copied out + of the source array and the signal has been updated on the remote PE. + - The signal update is performed atomically with respect to other signal + operations and synchronization routines. + - The signal variable must be of type uint64_t in symmetric memory. + - Use with nvshmem_signal_wait_until() for synchronization. + + Example: + ``` + # Transfer data and set completion flag to 1 + NVSHMEM_SIGNAL_SET = 0 + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, 1024, sig_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe + ) + ``` + """ + # Ensure sig_val is 64 bits + sig_val = 0 << 32 | sig_val + return putmem_signal_block_extern_wrapper( + dst.to(tl.int64), + src.to(tl.int64), + size_bytes.to(tl.int64), + signal.to(tl.int64), + sig_val.to(tl.uint64), + sig_op, + pe, + ) + + @core.extern + def putmem_signal_block_extern_wrapper( # type: ignore[no-untyped-def] + dst, + src, + size_bytes, + signal, + sig_val, + sig_op, + pe, + _semantic=None, + ): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, size_bytes, signal, sig_val, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("uint64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Wait and Signal Operations + + @triton.jit # type: ignore[misc] + def wait_until(ivar, cmp_op, cmp_val): # type: ignore[no-untyped-def] + """ + Wait until a tensor variable meets a specified condition. + + This high-level function provides a tensor-aware interface for NVSHMEM wait_until + operations. It automatically handles tensor address extraction, making + the API more ergonomic and type-safe. + + Args: + ivar_tensor: Tensor to monitor (typically int64/uint64) in symmetric memory. + cmp: Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until ivar == cmp_val + - NVSHMEM_CMP_NE (1): Wait until ivar != cmp_val + - NVSHMEM_CMP_GT (2): Wait until ivar > cmp_val + - NVSHMEM_CMP_GE (3): Wait until ivar >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until ivar < cmp_val + - NVSHMEM_CMP_LE (5): Wait until ivar <= cmp_val + cmp_val: Value to compare against. + + Notes: + - This is a blocking operation that will wait indefinitely until the + condition is satisfied. + - The tensor must be in symmetric memory and accessible from other PEs. + + Example: + ``` + # Wait until flag tensor becomes 1 (set by another PE) + NVSHMEM_CMP_EQ = 0 + nvshmem.wait_until_tensor(flag_tensor, NVSHMEM_CMP_EQ, 1) + ``` + """ + tl.static_assert( + ivar.type.element_ty.itemsize == 4, + "wait_until expects a 32-bit type for the synchronization variable", + ) + return wait_until_extern_wrapper(ivar.to(tl.int64), cmp_op, cmp_val) + + @core.extern + def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [ivar, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmem_int_wait_until", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def signal_wait_until(signal, cmp, cmp_val): # type: ignore[no-untyped-def] + """ + Wait until a signal variable meets a specified condition. + + This function blocks the calling thread until the value at the specified + signal variable satisfies the given comparison condition. Signal variables + are special uint64_t symmetric objects used for efficient synchronization + with signal operations. + + Args: + signal (tensor): Symmetric signal tensor with remote PE. + Must be 8-byte aligned symmetric memory. + cmp (int32): Comparison operator. Common values: + - NVSHMEM_CMP_EQ (0): Wait until signal == cmp_val + - NVSHMEM_CMP_NE (1): Wait until signal != cmp_val + - NVSHMEM_CMP_GT (2): Wait until signal > cmp_val + - NVSHMEM_CMP_GE (3): Wait until signal >= cmp_val + - NVSHMEM_CMP_LT (4): Wait until signal < cmp_val + - NVSHMEM_CMP_LE (5): Wait until signal <= cmp_val + cmp_val (int64): Value to compare against. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation designed specifically for signal variables. + - Signal variables are updated atomically by putmem_signal operations. + - More efficient than wait_until for signal-based synchronization patterns. + - Ensures the signal update is fully complete before returning. + - Commonly used with putmem_signal_block for producer-consumer patterns. + + Example: + ``` + # Wait for signal to be set to completion value + NVSHMEM_CMP_EQ = 0 + nvshmem.signal_wait_until(signal_ptr, NVSHMEM_CMP_EQ, 42) + ``` + """ + cmp_val = 0 << 32 | cmp_val + return signal_wait_until_extern_wrapper( + signal.to(tl.int64), cmp, cmp_val.to(tl.uint64) + ) + + @core.extern + def signal_wait_until_extern_wrapper(signal, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [signal, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int32"), + core.dtype("uint64"), + ): ("nvshmem_signal_wait_until", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): # type: ignore[no-untyped-def] + """ + Perform an atomic signal operation on a remote PE. + + This function atomically updates a signal variable on the specified remote PE + using the given operation and value. This enables efficient point-to-point + synchronization and notification between PEs. + + Args: + sig_addr (int64): Symmetric address of the signal variable (uint64_t) on the remote PE. + Must be 8-byte aligned symmetric memory. + signal (int64): Value to be used in the signal operation. + sig_op (int32): Signal operation type. Common values: + - NVSHMEM_SIGNAL_SET (0): Atomically set sig_addr = signal + - NVSHMEM_SIGNAL_ADD (5): Atomically set sig_addr += signal + pe (int32): PE number of the remote PE (0 ≤ pe < nvshmem_n_pes()). + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a one-sided operation - the remote PE does not need to participate. + - The signal operation is performed atomically on the remote PE. + - Can be used with signal_wait_until() on the remote PE for synchronization. + - Provides low-overhead notification mechanism between PEs. + - The signal variable must be of type uint64_t in symmetric memory. + + Example: + ```python + # Atomically set remote signal to 1 to notify completion + NVSHMEM_SIGNAL_SET = 0 + nvshmem.signal_op(remote_signal_ptr, 1, NVSHMEM_SIGNAL_SET, target_pe) + ``` + """ + return core.extern_elementwise( + "", + "", + [sig_addr, signal, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int32"), + core.dtype("int32"), + ): ("nvshmemx_signal_op", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Memory Ordering Operations + @core.extern + def fence(_semantic=None): # type: ignore[no-untyped-def] + """ + Ensure ordering of put operations to each remote PE. + + This function provides a memory fence that ensures point-to-point ordering + of remote memory operations. Put operations issued before the fence are + guaranteed to be ordered before put operations issued after the fence, + when targeting the same remote PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This provides weaker ordering guarantees than quiet(). + - Operations to each PE are ordered, but operations to different PEs + may still be reordered relative to each other. + - Does not guarantee completion of operations, only ordering. + - Non-blocking operations are not ordered by fence - use quiet() instead. + - Essential for ensuring correct ordering in communication patterns. + + Memory Ordering Guarantees: + - Put operations before fence() → ordered before → Put operations after fence() + - Ordering is maintained per-destination-PE basis + - Remote PEs can observe the enforced ordering + + Example: + ``` + # Ensure first put completes before second put to same PE + nvshmem.put(dst, src, nelems, target_pe) + nvshmem.fence() # Enforce ordering + nvshmem.put(dst2, src2, nelems, target_pe) + ``` + """ + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_fence", core.dtype("int32")), + }, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def quiet(_semantic=None): # type: ignore[no-untyped-def] + """ + Wait for completion of all outstanding put operations. + + This function blocks until all outstanding remote memory operations issued + by the calling PE have completed. It provides stronger guarantees than + fence() by ensuring both ordering and completion of all operations. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a blocking operation that waits for completion. + - Ensures all previous put operations have been delivered to their destinations. + - Provides global ordering - operations to ALL PEs are ordered. + - Required to complete non-blocking operations. + - More expensive than fence() but provides stronger guarantees. + + Memory Ordering Guarantees: + - All put operations before quiet() are completed before any operations after quiet() + - Operations are visible to all PEs as having occurred before subsequent operations + - Both blocking and non-blocking operations are completed + + Example: + ``` + # Ensure all data transfers complete before setting completion flag + nvshmem.putmem_block(data_ptr, src_ptr, data_size, target_pe) + nvshmem.quiet() # Wait for data transfer completion + nvshmem.putmem_block( + flag_ptr, flag_src_ptr, 8, target_pe + ) # Signal completion + ``` + """ + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_quiet", core.dtype("int32")), + }, + is_pure=False, + _semantic=_semantic, + ) + + # PE Information Operations + @core.extern + def my_pe(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the PE number of the calling PE. + + This function returns the unique identifier (PE number) of the current + processing element within the NVSHMEM job. PE numbers range from 0 to + nvshmem_n_pes() - 1. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: PE number of the calling PE (0 ≤ pe < nvshmem_n_pes()). + + Notes: + - This is a pure function that returns the same value throughout execution. + - PE numbering starts from 0 and is contiguous. + - Each PE has a unique identifier within the NVSHMEM job. + - Can be called from both host and device code. + - Essential for implementing PE-specific logic and communication patterns. + + Example: + ``` + # Get current PE number for conditional logic + pe = nvshmem.my_pe() + if pe == 0: + # Root PE logic + pass + else: + # Non-root PE logic + pass + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_my_pe", core.dtype("int32"))}, + is_pure=True, + _semantic=_semantic, + ) + + @core.extern + def n_pes(_semantic=None): # type: ignore[no-untyped-def] + """ + Get the total number of PEs in the NVSHMEM job. + + This function returns the total count of processing elements (PEs) + participating in the current NVSHMEM job. This value remains constant + throughout the execution of the program. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Total number of PEs in the job (always ≥ 1). + + Notes: + - This is a pure function that returns the same value throughout execution. + - The value is determined at NVSHMEM initialization and never changes. + - Valid PE numbers range from 0 to n_pes() - 1. + - Can be called from both host and device code. + - Essential for implementing collective operations and communication patterns. + + Example: + ``` + # Broadcast from root to all other PEs + total_pes = nvshmem.n_pes() + my_rank = nvshmem.my_pe() + + if my_rank == 0: + # Send to all other PEs + for peer in range(1, total_pes): + nvshmem.putmem_block(dst_ptr, src_ptr, size, peer) + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_n_pes", core.dtype("int32"))}, + is_pure=True, + _semantic=_semantic, + ) + + # Synchronization Operations + @core.extern + def barrier_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with completion guarantee. + + This function creates a barrier across all PEs in the NVSHMEM job. It ensures + that all local and remote memory updates issued before the barrier by any PE + are completed before any PE exits the barrier. This provides both + synchronization and memory consistency. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Stronger guarantee than sync_all() - ensures completion of remote operations. + - Blocks until all PEs reach the barrier AND all memory operations complete. + - Must be called from kernels launched with cooperative launch. + - Provides full memory consistency across all PEs. + - More expensive than sync_all() due to completion guarantees. + + Memory Consistency Guarantees: + - All memory updates before barrier_all() are visible to all PEs + - All remote memory operations are completed before any PE continues + - Provides a global synchronization point with memory ordering + + Example: + ``` + # Ensure all PEs complete their work before proceeding + # All PEs execute this - it's a collective operation + nvshmem.barrier_all() + # At this point, all previous operations are complete on all PEs + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_barrier_all", core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + @core.extern + def sync_all(_semantic=None): # type: ignore[no-untyped-def] + """ + Synchronize all PEs with local completion guarantee. + + This function creates a lightweight synchronization barrier across all PEs. + It ensures that all local store operations issued before the sync are + visible to other PEs, but does not guarantee completion of remote memory + operations initiated by the calling PE. + + Args: + _semantic: Optional semantic information for Triton compilation. + + Returns: + int32: Status code (0 for success). + + Notes: + - This is a collective operation - all PEs must participate. + - Lighter weight than barrier_all() - only ensures local store visibility. + - Does not guarantee completion of remote memory updates initiated locally. + - Must be called from kernels launched with cooperative launch. + - Suitable when only synchronization (not completion) is needed. + - More efficient than barrier_all() for synchronization-only patterns. + + Memory Consistency Guarantees: + - Local store operations are visible to other PEs + - Does NOT ensure completion of outgoing remote operations + - Provides synchronization point without full completion overhead + + Example: + ``` + # Lightweight synchronization between PEs + # All PEs execute this - it's a collective operation + nvshmem.sync_all() + # Local stores are visible, but remote ops may still be in flight + ``` + """ + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_sync_all", core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + # Collective Operations (mem-based APIs - sizes in bytes) + @triton.jit # type: ignore[misc] + def alltoall(team, dest, source, nelems_per_pe): # type: ignore[no-untyped-def] + """ + All-to-all tensor exchange between PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM alltoall + operations. Each PE sends nelems_per_pe elements to every other PE and receives + the same amount from every other PE. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor. Must be large enough for nelems_per_pe * n_pes elements. + source: Source tensor containing data for all PEs. Must contain nelems_per_pe * n_pes elements. + nelems_per_pe: Number of elements to exchange with each PE. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Data layout: source=[data_for_pe0, data_for_pe1, ...], dest=[data_from_pe0, data_from_pe1, ...] + + Example: + ``` + # Each PE exchanges 10 elements with every other PE + nvshmem.alltoall(0, dest_tensor, src_tensor, 10) + ``` + """ + tl.static_assert(dest.type == source.type) + size_bytes_per_pe = nelems_per_pe * dest.type.element_ty.itemsize + return alltoallmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), size_bytes_per_pe.to(tl.int64) + ) + + @core.extern # type: ignore[misc] + def alltoallmem_block_extern_wrapper( + team: Any, dest: Any, source: Any, size_bytes: Any, _semantic: Any = None + ) -> None: + """Low-level extern wrapper for NVSHMEM alltoall""" + return core.extern_elementwise( + "", + "", + [team, dest, source, size_bytes], + { + ( + core.dtype("int32"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + ): ("nvshmemx_alltoallmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + @triton.jit # type: ignore[misc] + def broadcast(team, dest, source, nelems, pe_root): # type: ignore[no-untyped-def] + """ + Broadcast tensor data from a root PE to all other PEs in a team. + + This high-level function provides a tensor-aware interface for NVSHMEM broadcast + operations. It automatically handles type checking and size calculations, making + the API more ergonomic and type-safe. + + Args: + team: Team handle for the collective operation. Use 0 for NVSHMEM_TEAM_WORLD. + dest: Destination tensor with type information. All PEs receive data here. + source: Source tensor on the root PE. Type must match dest. + nelems: Number of elements to broadcast. + pe_root: PE number of the root PE that provides the source data. + + Notes: + - Performs compile-time type checking between dest and source tensors. + - Automatically calculates byte size from tensor type and element count. + - This is a collective operation - all PEs in the team must participate. + - Must be called from kernels launched with cooperative launch. + + Example: + ``` + # Broadcast 100 elements from PE 0 to all PEs + nvshmem.broadcast(0, dest_tensor, src_tensor, 100, 0) + ``` + """ + tl.static_assert(dest.type == source.type) + nbytes = nelems * dest.type.element_ty.itemsize + return broadcastmem_block_extern_wrapper( + team, dest.to(tl.int64), source.to(tl.int64), nbytes.to(tl.int64), pe_root + ) + + @core.extern # type: ignore[misc] + def broadcastmem_block_extern_wrapper( + team: Any, + dest: Any, + source: Any, + size_bytes: Any, + pe_root: Any, + _semantic: Any = None, + ) -> None: + """Low-level extern wrapper for NVSHMEM broadcast""" + return core.extern_elementwise( + "", + "", + [team, dest, source, size_bytes, pe_root], + { + ( + core.dtype("int32"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # size in bytes + core.dtype("int32"), # pe_root + ): ("nvshmemx_broadcastmem_block", core.dtype("int32")) + }, + is_pure=False, + _semantic=_semantic, + ) + + # Reduction Operation + @triton.jit # type: ignore[misc] + def reduce(team, dest, source, nreduce, operation: tl.constexpr): # type: ignore[no-untyped-def] + """ + Performs a collective reduction on tensors across a team of PEs. + + This high-level function provides a tensor-aware interface for NVSHMEM + reduction operations. It automatically infers the data type from the + input tensors and calls the appropriate underlying NVSHMEM function. + + Args: + team: The team handle for the collective (0 for NVSHMEM_TEAM_WORLD). + dest: Destination tensor for the reduction results. + source: Source tensor containing data to be reduced. Must be the same type as dest. + nreduce: The number of elements in the source tensor to reduce. + operation: The reduction operation to perform ("sum", "max", "min", "prod"). + + Notes: + - Performs compile-time type checking between dest and source tensors. + - This is a collective operation that must be called by all PEs in the team. + - Requires a cooperative grid launch. + + Example: + ``` + # Perform a sum reduction on two tensors + nvshmem.reduce(0, dest_tensor, src_tensor, 100, "sum") + ``` + """ + tl.static_assert(dest.type == source.type) + dtype = dest.type.element_ty + return reduce_extern_wrapper( + team, + dest.to(tl.int64), + source.to(tl.int64), + nreduce.to(tl.int64), + operation, + dtype, + ) + + @core.extern # type: ignore[misc] + def reduce_extern_wrapper( + team: Any, + dest: Any, + source: Any, + nreduce: Any, + operation: str, + dtype: Any, + _semantic: Any = None, + ) -> None: + """ + Low-level extern wrapper for NVSHMEM reduction operations. + + This function provides a generic interface to NVSHMEM reduction operations, + automatically selecting the appropriate NVSHMEM function based on the data type + and operation specified. + Args: + team (int64): The team handle (0 for NVSHMEM_TEAM_WORLD). + dest (pointer): Destination pointer where reduction results are stored. + source (pointer): Source pointer containing data to be reduced. + nreduce (int64): Number of elements to reduce. + operation (str): Reduction operation ("sum", "max", "min", "prod"). + dtype: Data type specification - accepts torch.dtype, tl.dtype, str, or constexpr. + _semantic: Optional semantic information for Triton compilation. + + Raises: + ValueError: If the operation is not supported. + TypeError: If the data type is not supported. + + Example: + nvshmem.reduce(0, dest_ptr, src_ptr, 100, "sum", torch.float32) + """ + # Mapping from Triton dtype names to NVSHMEM typenames + DTYPE_TO_NVSHMEM_MAP = { + "int8": "int8", + "int16": "int16", + "int32": "int32", + "int64": "int64", + "uint8": "uint8", + "uint16": "uint16", + "uint32": "uint32", + "uint64": "uint64", + "fp16": "half", + "bf16": "bfloat16", + "fp32": "float", + "fp64": "double", + } + + # Triton dtype names are standardized as fp16, bf16, fp32, etc. + dtype_name = str(dtype).replace("tl.", "") + + if dtype_name not in DTYPE_TO_NVSHMEM_MAP: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes: {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Extract operation name from constexpr if needed + op_name = operation.value if hasattr(operation, "value") else operation + + # Validate operation is supported + supported_ops = {"sum", "max", "min", "prod"} + if op_name not in supported_ops: + raise ValueError( + f"Unsupported reduction operation: '{op_name}'. Supported ops are {supported_ops}" + ) + + # Map to NVSHMEM typename and validate dtype is supported + nvshmem_typename = DTYPE_TO_NVSHMEM_MAP.get(dtype_name) + if nvshmem_typename is None: + raise TypeError( + f"Unsupported reduction dtype: {dtype_name}. Supported dtypes are {list(DTYPE_TO_NVSHMEM_MAP.keys())}" + ) + + # Generate NVSHMEM function name + nvshmem_func = f"nvshmem_{nvshmem_typename}_{op_name}_reduce" + + # Define function signature - all parameters are int64 in Triton (they are just ptrs) + signature = ( + core.dtype("int32"), # team handle + core.dtype("int64"), # destination pointer + core.dtype("int64"), # source pointer + core.dtype("int64"), # number of elements + ) + + return core.extern_elementwise( + "", + "", + [team, dest, source, nreduce], + {signature: (nvshmem_func, core.dtype("int32"))}, + is_pure=False, + _semantic=_semantic, + ) + + # Utility for inspecting Triton kernels + + triton_kernels: dict = {} + + def _log_triton_kernel(kernel) -> None: # type: ignore[no-untyped-def] + import atexit + import tempfile + + if dist.is_initialized() and dist.get_rank() != 0: + return + + def on_exit() -> None: + logger.info("PTX files:") + for kernel in triton_kernels: + with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as f: + f.write(kernel.asm["ptx"].encode("utf-8")) + logger.info(f"+- {kernel.name}: {f.name}") # noqa: G004 + + if len(triton_kernels) == 0: + atexit.register(on_exit) + + if kernel not in triton_kernels: + triton_kernels[kernel] = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5559cc10fabdc1172c9a3ac95ee48ca72b2d65f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/__init__.py @@ -0,0 +1,45 @@ +""" +NOTICE: DTensor has moved to torch.distributed.tensor + +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" + +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"torch.distributed.tensor.{submodule}" + sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from torch.distributed.tensor import ( # noqa: F401 + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + init_device_mesh, + ones, + Partial, + Placement, + rand, + randn, + Replicate, + Shard, + zeros, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5742156a86ca511619360038a9028b0efeeaef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/api.py @@ -0,0 +1,9 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._api import * # noqa: F401, F403 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4e70dbba455471feef2326cae8ba28b32d0304 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tensor/placement_types.py @@ -0,0 +1,10 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from torch.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22e974cdd64f1082e7a89e441eb8c90163f56d3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__init__.py @@ -0,0 +1,12 @@ +from .fsdp2_mem_tracker import FSDPMemTracker +from .mem_tracker import MemTracker +from .memory_tracker import MemoryTracker +from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator +from .sac_estimator import ( + MSPS, + SACEstimator, + SACGreedyOrderMeta, + SACStats, + SACTradeOffStats, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103e3bbab07f7f6989987b009e70d23332e34057 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628f1e2a7fb221241836ca4f96171f43dae12d61 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd1f64d369adcae6ed17f69012a57d09bd638d23 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7534d9e892ced56e848979452ac447fc93cb4fe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d5a17734122842ded02424ee04177c6ebba375e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3113405c938c78c72ffbf1f81d75cbe06278dc5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93bfded11dc58c8ff0245cfcbe09c2b0a022427c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..297e4b6eab32be02bb0f827c050382656012bf0a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87f1495c1e2b2e9fe61aba8a1381a63381a0cad Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd28e4811e2433cee4ebe6b9eab449b47f8ff8e7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fbb9c586b440c1b4882111cf5bf16bc24f9e69a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0188a4aa08440e05bcdbbff8c9d14c05540a7909 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/common_utils.py @@ -0,0 +1,33 @@ +import warnings + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: + """ + Recursively extracts untyped storages from a tensor or its subclasses. + + Args: + t (torch.Tensor): The tensor to extract storages from. + + Returns: + Set[torch.UntypedStorage]: A set of untyped storages. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac0f8a764d3eca836de98bd82d5495817eadf5b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fake_collectives.py @@ -0,0 +1,307 @@ +import random +from typing import Any + +import torch +from torch._C._distributed_c10d import ( + _resolve_process_group, + FakeWork, + ProcessGroup, + Work, +) +from torch.utils._pytree import tree_map_only + + +torch.distributed.batch_isend_irecv + +c10d = torch.ops.c10d +_c10d_functional = torch.ops._c10d_functional +_c10d_functional_autograd = torch.ops._c10d_functional_autograd +_dtensor = torch.ops._dtensor +used_ids: set[int] = set() + + +def generate_unique_id() -> int: + while True: + new_id = random.randint(1, 10**9) + if new_id not in used_ids: + used_ids.add(new_id) + return new_id + + +# Function to create and return FakeWork object +def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def] + work = FakeWork() + work.seq_id = generate_unique_id() + fakework_script_obj = work.boxed() + return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj + + +# Dictionary mapping collective operations to their meta functions +# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included +# _DEPRECATED_META_FUNCTIONS = { +# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# } +_META_FUNCTIONS = { + "broadcast_": lambda *args: create_fakework(args), + "allreduce_": lambda *args: create_fakework(args), + "allgather_": lambda *args: create_fakework(args), + "_allgather_base_": lambda *args: create_fakework(args), + "reduce_scatter_": lambda *args: create_fakework(args), + "_reduce_scatter_base_": lambda *args: create_fakework(args), + "reduce_": lambda *args: create_fakework(args, return_first_arg=False), + "gather_": lambda *args: create_fakework(args, return_first_arg=False), + "scatter_": lambda *args: create_fakework(args), + "alltoall_": lambda *args: create_fakework(args), + "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False), + "barrier": lambda *args: create_fakework(args, return_first_arg=False), + "monitored_barrier_": lambda *args: None, + "send": lambda *args: create_fakework(args, return_first_arg=False), + "recv_": lambda *args: create_fakework(args, return_first_arg=False), + "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), +} + +lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 +for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") + +# List of collective operation functions including functional collectives +# Note: The following collectives might be deprecated soon hence not adding them +# depcreated_non_functional_collectives = [ +# c10d.allreduce_coalesced_.default, +# c10d.reduce_scatter_tensor_coalesced_.default, +# c10d.allgather_into_tensor_coalesced_.default, +# c10d.allgather_coalesced_.default, +# ] +non_functional_collectives: set[torch._ops.OpOverload] = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + c10d.monitored_barrier_.default, +} +functional_collectives: set[torch._ops.OpOverload] = { + _c10d_functional.broadcast.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_out.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _c10d_functional.wait_tensor.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.broadcast_.default, + _dtensor.shard_dim_alltoall.default, +} + +sync_ops: set[torch._ops.OpOverload] = { + c10d.barrier.default, + c10d.monitored_barrier_.default, + _c10d_functional.wait_tensor.default, +} + +collective_ops = set.union(functional_collectives, non_functional_collectives) + + +class CollectiveOp: + # Static sets for performance optimization + PG_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.barrier.default, + # c10d.allreduce_coalesced_.default + } + + PG_ARG_2 = { + c10d.allgather_.default, + c10d._allgather_base_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + # c10d.allgather_coalesced_.default, + # c10d.allgather_into_tensor_coalesced_.default + # c10d.reduce_scatter_tensor_coalesced_.default + } + + PG_ARG_3 = { + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + } + + PG_ARG_4 = { + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + WK_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + } + + WK = { + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.reduce_.default, + c10d.gather_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + } + + COMM_TENSOR_ARG_0 = { + c10d.allreduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.gather_.default, + c10d.reduce_.default, + c10d.broadcast_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + # c10d.allreduce_coalesced_.default + # c10d.allgather_coalesced_.default + # c10d.allgather_into_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_1 = { + c10d.reduce_scatter_.default, + c10d.scatter_.default, + # c10d.reduce_scatter_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_RES = { + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + } + + COMM_TENSOR_SINGLE_UNTYPED_STORAGE = { + c10d._allgather_base_.default, + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + } + + COMM_TENSOR_ARG_0_AND_RES = { + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + COMM_TENSOR_RES_SUM = { + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + } + + @staticmethod + def sum_tensors(arg: Any) -> int: + """Calculate total memory consumed by the tensors in the argument.""" + total_memory = 0 + + def sum_bytes(t: torch.Tensor) -> None: + nonlocal total_memory + total_memory += t.untyped_storage().nbytes() + + tree_map_only(torch.Tensor, sum_bytes, arg) + return total_memory + + @staticmethod + def get_process_group(func, args) -> ProcessGroup: # type: ignore[no-untyped-def] + """Retrieve the process group for collective operations, except `wait_tensor`.""" + if func in CollectiveOp.PG_ARG_1: + return ProcessGroup.unbox(args[1]) + if func in CollectiveOp.PG_ARG_2: + return ProcessGroup.unbox(args[2]) + if func in CollectiveOp.PG_ARG_3: + return _resolve_process_group(args[2]) + if func in CollectiveOp.PG_ARG_4: + return _resolve_process_group(args[3]) + raise TypeError(f"Func {func} not found in {collective_ops}") + + @staticmethod + def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-untyped-def] + """Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`.""" + if func in CollectiveOp.COMM_TENSOR_ARG_0: + return CollectiveOp.sum_tensors(args[0]) + if func in CollectiveOp.COMM_TENSOR_ARG_1: + return CollectiveOp.sum_tensors(args[1]) + if func in CollectiveOp.COMM_TENSOR_ARG_RES: + return res.untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE: + return args[0].untyped_storage().nbytes() + if func is c10d._reduce_scatter_base_.default: + return args[1].untyped_storage().nbytes() + if func is c10d.alltoall_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1]) + ) + if func is c10d.alltoall_base_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes() + ) + if func == _c10d_functional.all_gather_into_tensor_out.default: + return args[-1].untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_RES_SUM: + return CollectiveOp.sum_tensors(res) + if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES: + # TODO(@sanketpurandare) - Confirm size computation + return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes() + raise TypeError(f"Unknown function: {func} in {collective_ops}") + + @staticmethod + def get_work(func, res) -> Work: # type: ignore[no-untyped-def] + if func in CollectiveOp.WK: + return FakeWork.unbox(res) + elif func in CollectiveOp.WK_ARG_1: + return FakeWork.unbox(res[1]) + raise TypeError(f"Func {func} not found in {collective_ops}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7db24cad45b1a69a525efab736437fc48899a6d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -0,0 +1,578 @@ +from collections.abc import Callable +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, NamedTuple, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker +from torch.distributed.fsdp import FSDPModule +from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup +from torch.distributed.tensor import DTensor +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +_TOTAL_KEY = "Total" + +__all__ = ["FSDPMemTracker"] + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_Ts = TypeVarTuple("_Ts") + +c10d = torch.ops.c10d + + +class _FSDPRefType(_RefType): + """ + Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations, + and optimizer states. + + Attributes: + SHARDED_PARAM (str): Memory usage of sharded parameters. + UNSHARDED_PARAM (str): Memory usage of unsharded parameters. + SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters. + UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters. + ACT (str): Memory usage of activations and tensors from forward and AC recomputation. + TEMP (str): Memory usage of temporary tensors during the backward pass including gradients of activations. + ALL_GATHER (str): Memory usage of all_gather output tensor. + REDUCE_SCATTER (str): Memory usage of reduce_scatter input tensor. + OPT (str): Memory usage of tensors storing optimizer states. + INP (str): Memory usage of input tensors. + """ + + SHARDED_PARAM = "Sharded Param" + UNSHARDED_PARAM = "Unsharded Param" + BUFFER = "Buffer" + SHARDED_GRAD = "Sharded Grad" + UNSHARDED_GRAD = "Unsharded Grad" + ACT = "Activation" + TEMP = "Temp" + ALL_GATHER = "All Gather" + REDUCE_SCATTER = "Reduce Scatter" + OPT = "OptState" + INP = "Inputs" + + +class _SavedFSDPMethods(NamedTuple): + pre_backward: Callable + post_backward: Callable + + +class _FSDPModState(_State): + """ + Enumerates the states of FSDP modules during the forward and backward passes. + """ + + BEF_PRE_FW = "Before Pre-Forward" + AFT_PRE_FW = "After Pre-Forward" + BEF_POST_FW = "Before Post-Forward" + AFT_POST_FW = "After Post-Forward" + BEF_PRE_BW = "Before Pre-Backward" + AFT_PRE_BW = "After Pre-Backward" + BEF_POST_BW = "Before Post-Backward" + AFT_POST_BW = "After Post-Backward" + PRE_FW_AC = "Pre-Forward AC" + POST_FW_AC = "Post-Forward AC" + PEAK_FW = "Peak Forward" + PEAK_BW = "Peak Backward" + + +class _FSDPModMemStats: + """ + A class to store the memory statistics of an FSDP module. + + Args: + mod_fqn (str): The fully qualified name of the FSDP module. + + Attributes: + snapshots (Dict[_FSDPModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states as defined by ``_FSDPModState``. Each key is a device, and + each value is another dictionary with keys as memory reference types defined by ``_FSDPRefType`` and + values as the memory consumed in bytes. + + """ + + def __init__(self, mod_fqn: str) -> None: + self.mod_fqn = mod_fqn + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[ + _FSDPModState, list[dict[torch.device, dict[str, int]]] + ] = {} + + +class _FSDPState(Enum): + PRE_FW = auto() + FW = auto() + POST_FW = auto() + PRE_BW = auto() + BW = auto() + POST_BW = auto() + + +class FSDPMemTracker(MemTracker): + """ + A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track + and categorize the peak memory and module-wise memory usage of FSDP modules. + + It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes + the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of + the module execution defined by ``_FSDPModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference + to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module. + + Args: + mod (torch.nn.Module): The root FSDP module to be tracked. + optm (torch.optim.Optimizer, optional): The optimizer to be tracked. + + Note: Please refer to ``torch.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations. + + Example usage + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + fmt = FSDPMemTracker(module, optimizer) + fmt.track_inputs((inp,)) + with fmt: + optimizer.zero_grad() + loss = module(inp) + print("After Forward:") + fmt.display_snapshot("current") + loss.backward() + optimizer.step() + fmt.display_snapshot("peak") + fmt.display_modulewise_snapshots(depth=3, units="MB") + + """ + + def __init__( + self, + mod: torch.nn.Module, + optm: torch.optim.Optimizer | None = None, + ) -> None: + super().__init__() + assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules" + self._root_mod = mod + self._optm = optm + self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary() + self._fsdp_state: _FSDPState = _FSDPState.PRE_FW + self._ref_class: type[_RefType] = _FSDPRefType + + def _instrument_fsdp_sharded_params_grads( + self, fsdp_param_group: FSDPParamGroup + ) -> None: + # Track sharded params and grads after initialization + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.sharded_param, + _FSDPRefType.SHARDED_PARAM, + ) + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + def _fsdp_state_pre_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]], + ) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]: + # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params + # and `all_gather` buffers. There are three cases: + # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats`` + # instance for the module and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op. + @wraps(orig_fsdp_state_pre_fw) + def inner( + *args: _P.args, **kwargs: _P.kwargs + ) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]: + self._fsdp_state = _FSDPState.PRE_FW + mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod) + assert mod_fqn is not None + if fsdp_mod not in self.memory_tracking: + mod_stat = _FSDPModMemStats(mod_fqn) + self.memory_tracking[fsdp_mod] = mod_stat + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append( + snapshot + ) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append( + deepcopy(snapshot) + ) + elif not self._mod_tracker.is_bw: + parents = self._mod_tracker.parents - {mod_fqn} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "FSDPMemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + + # pyrefly: ignore [bad-assignment] + args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs) + + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.unsharded_param, + _FSDPRefType.UNSHARDED_PARAM, + ) + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(fsdp_mod) + self._in_ac = True + else: + state = _FSDPModState.AFT_PRE_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.FW + return args, kwargs + + return inner + + def _fsdp_state_post_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_post_fw: Callable[_P, _R], + ) -> Callable[_P, _R]: + # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state + # if ``reshard_after_forward`` is not ``False``. There are two cases: + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward. + @wraps(orig_fsdp_state_post_fw) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is fsdp_mod: + self._ac_mod = None + self._in_ac = False + else: + state = _FSDPModState.BEF_POST_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.POST_FW + + output = orig_fsdp_state_post_fw(*args, **kwargs) + + if not self._mod_tracker.is_bw: + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append( + self.get_tracker_snapshot() + ) + return output + + return inner + + def _fsdp_param_group_pre_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_pre_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching + # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module. + @wraps(orig_fsdp_param_group_pre_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + self._fsdp_state = _FSDPState.PRE_BW + mod_stat = self.memory_tracking[fsdp_mod] + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append( + deepcopy(snapshot) + ) + orig_fsdp_param_group_pre_backward(*args, **kwargs) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.BW + + return inner + + def _fsdp_param_group_post_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_post_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute + # the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers + # after the post backward. + @wraps(orig_fsdp_param_group_post_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + unsharded_grad = fsdp_param._unsharded_param.grad + if unsharded_grad is not None: + self._update_and_maybe_create_winfos( + unsharded_grad, + _FSDPRefType.UNSHARDED_GRAD, + update_existing=True, + ) + + mod_stat = self.memory_tracking[fsdp_mod] + mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.POST_BW + orig_fsdp_param_group_post_backward(*args, **kwargs) + + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append( + self.get_tracker_snapshot() + ) + + return inner + + def _instrument_fsdp_module(self) -> None: + # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install + # our own hooks that wrap them. We choose this over monkey-patching `FSDPParamGroup.pre_forward` and + # `FSDPParamGroup.post_forward` because during AC these won't be called. + # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786) + # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`. + + # get the unique _MultiHandlers/RemoveHandlers and store in dictionary + # the _MultiHandlers object will only need to be grabbed once. + unique_handlers: dict[RemovableHandle, bool] = {} + # pyrefly: ignore # missing-attribute + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + if not unique_handlers.get(fsdp_state._pre_forward_hook_handle): + unique_handlers[fsdp_state._pre_forward_hook_handle] = True + if not unique_handlers.get(fsdp_state._post_forward_hook_handle): + unique_handlers[fsdp_state._post_forward_hook_handle] = True + # call remove on the handles once + for f_hook_handle in unique_handlers: + f_hook_handle.remove() + # pyrefly: ignore # missing-attribute + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + self._instrument_fsdp_sharded_params_grads(fsdp_param_group) + fsdp_state._pre_forward_hook_handle = ( + # pyrefly: ignore [missing-attribute] + module.register_forward_pre_hook( + self._fsdp_state_pre_forward( + module, fsdp_state._pre_forward + ), + prepend=True, + with_kwargs=True, + ) + ) + # pyrefly: ignore [missing-attribute] + fsdp_state._post_forward_hook_handle = module.register_forward_hook( + self._fsdp_state_post_forward(module, fsdp_state._post_forward), + prepend=False, + always_call=True, + ) + self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods( + fsdp_param_group.pre_backward, + fsdp_param_group.post_backward, + ) + fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward( # type: ignore[assignment] + module, fsdp_param_group.pre_backward + ) + fsdp_param_group.post_backward = ( # type: ignore[assignment] + self._fsdp_param_group_post_backward( + module, fsdp_param_group.post_backward + ) + ) + + # pyrefly: ignore [missing-attribute] + for buffer in self._root_mod.buffers(): + self._update_and_maybe_create_winfos( + buffer, + _FSDPRefType.BUFFER, + ) + + def _instrument_optimizer(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + if self._optm is not None: + self._track_optimizer_states(_FSDPRefType.OPT, self._optm) + + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_FSDPRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + self._optm.register_step_pre_hook(_opt_step_pre_hook), + self._optm.register_step_post_hook(_opt_step_post_hook), + ) + + def _register_module_and_optimizer_hooks(self) -> None: + self._instrument_fsdp_module() + self._instrument_optimizer() + + def _deregister_module_and_optimizer_hooks(self) -> None: + for ( + fsdp_mod, + saved_methods, + ) in self._fsdp_mod_to_saved_methods.items(): + fsdp_state = fsdp_mod._get_fsdp_state() + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook( + fsdp_state._pre_forward, prepend=True, with_kwargs=True + ) + fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook( + fsdp_state._post_forward, prepend=False + ) + if fsdp_param_group := fsdp_state._fsdp_param_group: + fsdp_param_group.pre_backward = saved_methods.pre_backward + fsdp_param_group.post_backward = saved_methods.post_backward + self._fsdp_mod_to_saved_methods.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_inputs(self, inputs: tuple[Any, ...]) -> None: + """ + This is used to track the input tensors to the model and annotate them as ``Inputs``. + Args: + inputs (Tuple[Any]): A tuple containing the input data. This can include tensors + as well as other data types. Only tensors will be tracked. + """ + + def _track_inputs(t: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + t, + _FSDPRefType.INP, + ) + + tree_map_only(torch.Tensor, _track_inputs, inputs) + + def track_external( + self, *external: nn.Module | optim.Optimizer | torch.Tensor + ) -> None: + """This is no-op for ``FSDPMemTracker``""" + + def __enter__(self) -> "FSDPMemTracker": + if self._depth == 0: + self._register_module_and_optimizer_hooks() + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + TorchDispatchMode.__enter__(self) + self._depth += 1 + return self + + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_module_and_optimizer_hooks() + self._restore_resize() + self._mod_tracker.__exit__(*args) + TorchDispatchMode.__exit__(self, *args) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into local tensor ops, before `MemTracker` sees them. + if any(t == DTensor for t in types): + return NotImplemented + if ( + func is torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + # pyrefly: ignore [unsupported-operation] + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _FSDPRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _FSDPRefType.TEMP + else: + reftype = _FSDPRefType.ACT + if func is c10d._allgather_base_.default and self._fsdp_state in [ + _FSDPState.PRE_FW, + _FSDPState.PRE_BW, + ]: + # pyrefly: ignore [unsupported-operation] + output_tensor = args[0] + self._update_and_maybe_create_winfos( + output_tensor, + _FSDPRefType.ALL_GATHER, + update_existing=True, + ) + if ( + func is c10d._reduce_scatter_base_.default + and self._fsdp_state == _FSDPState.POST_BW + ): + # pyrefly: ignore [unsupported-operation] + input_tensor = args[1] + self._update_and_maybe_create_winfos( + input_tensor, + _FSDPRefType.REDUCE_SCATTER, + update_existing=True, + ) + + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = ( + _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW + ) + self._update_peak_stats(peak_state) + return res diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8ba4195ffd20323d419642159fe199549e3de1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/ilp_utils.py @@ -0,0 +1,292 @@ +import copy +from collections import OrderedDict +from typing import cast, TypedDict + +import numpy as np + +import torch +from torch.distributed._tools.mem_tracker import ( + _MemRefType, + _ModMemStats, + _ModState, + MemTracker, +) +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats + + +class ModOrder(TypedDict): + fw_pre_order: list[str] + bw_pre_order: list[str] + fw_post_order: list[str] + bw_post_order: list[str] + + +class ModRuntime(TypedDict): + fw: float + bw: float + + +class ModStats(TypedDict): + fqn: str + # per-module params + param_per_module: int + # per-module grads + grad_per_module: int + # total accumulated gradients up to and including this module + grad_total: int + # per module fw activation size (excluding input and output) + act_fw_per_module: int + # per module bw activation size during peak_bw + act_bw_per_module: int + # per module activation grad size during peak_bw + act_grad_per_module: int + # total activation size up to but excluding the current module + # includes input of the current module (i.e., output of previous module) + act_total: int + # Inputs to the module + input_per_module: int + # Outputs of the module + output_per_module: int + # Total fw run-time of the module + fw_runtime_per_module: float + # Total bw run-time of the module + bw_runtime_per_module: float + # Is this module a leaf module + is_leaf: bool + # Total ac run-time of the module + sac_runtime: float + # Total ac_memory for the module + sac_memory: int + # Number of piecewise-linear functions used for approximating ac tradeoff curve + n_segments: int + # Slopes of the of piecewise-linear functions + slopes: list[float] + # Intercepts of the of piecewise-linear functions + intercepts: list[float] + # X breakpoints of the of piecewise-linear functions + breakpoints: list[float] + # Original trade-off curves + tradeoff_curve: OrderedDict[float, float] + + +class ModuleInfo(TypedDict): + mod_order: ModOrder + mod_stats: list[ModStats] + + +def aggregate_stats( + model: torch.nn.Module, + mem_tracker: MemTracker, + runtime_estimator: RuntimeEstimator, + sac_estimator: SACEstimator, + dev: torch.device, +) -> ModuleInfo: + """ + Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats. + + Args: + model: nn.Module object + runtime_estimator: RuntimeEstimator object with runtime stats + mem_tracker: MemTracker object with memory stats + sac_estimator: SACEstimator object with AC tradeoff stats + dev: device the model was run on (used to extract memory stats from MemTracker) + + Returns: + ModuleInfo: A dictionary with module order and module stats. + """ + + # Memory stats + mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict( + copy.deepcopy(mem_tracker.memory_tracking) + ) + + # Runtime stats + mod_runtime_stats: dict[str, ModRuntime] = { + fqn: {"fw": v["fw"], "bw": v["bw"]} + for fqn, v in runtime_estimator.mod_runtimes.items() + } + + # Module order + mod_order: ModOrder = { + "fw_pre_order": list(runtime_estimator.mod_fw_pre_order), + "bw_pre_order": list(runtime_estimator.mod_bw_pre_order), + "fw_post_order": list(runtime_estimator.mod_fw_post_order), + "bw_post_order": list(runtime_estimator.mod_bw_post_order), + } + + # Selective Activation Checkpointing stats + sac_estimator.pwlf_sac_tradeoff_curve() + mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy( + sac_estimator.sac_mod_tradeoff_stats + ) + + module_info: ModuleInfo = { + "mod_order": mod_order, + "mod_stats": [], + } + + for mod in model.modules(): + if mod_mem_stat := mod_mem_stats.get(mod): + if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None): + sac_runtime = tradeoff_stats.sac_runtime + sac_memory = tradeoff_stats.sac_memory + n_segments = tradeoff_stats.n_segments + slopes = tradeoff_stats.slopes + intercepts = tradeoff_stats.intercepts + breakpoints = tradeoff_stats.fit_breaks + tradeoff_curve = tradeoff_stats.tradeoff_curve + is_leaf = False + else: + sac_runtime = sac_memory = n_segments = 0 + slopes = intercepts = breakpoints = [] + tradeoff_curve: OrderedDict[float, float] = OrderedDict() # type: ignore[no-redef] + is_leaf = True + mod_stat: ModStats = { + "fqn": mod_mem_stat.mod_fqn, + "param_per_module": mod_mem_stat.parameter_mem, + "grad_per_module": mod_mem_stat.parameter_mem, + "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.GRAD + ], + "act_fw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.output_mem, + ), + "act_bw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT], + ), + "act_grad_per_module": ( + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP] + - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.TEMP + ] + ), + "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][ + _MemRefType.ACT + ], + "input_per_module": mod_mem_stat.input_mem, + "output_per_module": mod_mem_stat.output_mem, + "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"], + "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"], + "is_leaf": is_leaf, + "sac_runtime": sac_runtime, + "sac_memory": sac_memory, + "n_segments": n_segments, + "slopes": slopes, + "intercepts": intercepts, + "breakpoints": breakpoints, + "tradeoff_curve": tradeoff_curve, + } + module_info["mod_stats"].append(mod_stat) + + return module_info + + +class Node(ModStats): + index: int # index according to forward pre-order + pos_fw_post_order: int # index according to forward post-order + + +class Graph: + def __init__(self, n: int) -> None: + self.nodes: list[Node] = [] + self.name2node: dict[str, Node] = {} + self.ad_matrix = np.zeros((n, n)) + self.fw_post_order: list[str] = [] + + def add_node(self, node: Node) -> None: + self.nodes.append(node) + self.name2node[node["fqn"]] = node + + +def parse_module_info(module_info: ModuleInfo) -> Graph: + """ + Parse module info and create a graph (tree) of modules. The graph will be + used by MILP solver to find optimal SAC and/or FSDP configurations. + """ + mod_stats = module_info["mod_stats"] + fw_pre_order = module_info["mod_order"]["fw_pre_order"] + # assertion and number of nodes + assert len(mod_stats) == len(fw_pre_order) + n_nodes = len(mod_stats) + + # create graph + g = Graph(n_nodes) + g.fw_post_order = module_info["mod_order"]["fw_post_order"] + + # sort the modules by pre-order and add them to the graph + module_info["mod_stats"] = sorted( + mod_stats, key=lambda x: fw_pre_order.index(x["fqn"]) + ) + for i, one_mod_stats in enumerate(mod_stats): + node: Node = cast(Node, one_mod_stats) + node["index"] = i + node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"]) + g.add_node(node) + + # set up ancestor-descendant matrix + for i in range(n_nodes): + for j in range(i, n_nodes): + if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]): + g.ad_matrix[i][j] = 1 + else: + break + + return g + + +def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + check if name_descendant is a submodule of name_ancestor, or if they are the same + """ + return name_descendant == name_ancestor or name_ancestor + "." in name_descendant + + +def is_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + if name_descendant is a submodule of name_ancestor, but not the same + """ + return name_ancestor + "." in name_descendant + + +def display_bytes(b: int, unit: str = "MiB") -> str: + """ + return a string that represent the number of bytes in a desired unit + """ + if unit == "KiB": + return f"{b / 2**10:.2f} KiB" + if unit == "MiB": + return f"{b / 2**20:.2f} MiB" + if unit == "GiB": + return f"{b / 2**30:.2f} GiB" + return f"{b:.2f} bytes" + + +def get_peak_memory_runtime_baseline(graph: Graph) -> tuple[int, float]: + """ + Get the baseline peak memory and runtime. + Baseline here means there is no FSDP or AC. + Memory includes the parameters, gradients, activations, and activation gradients. + Memory does not include e.g., optimizer states, embedding tables, etc. + + Returns: + int: peak memory in bytes + float: compute time in ms + """ + P_1 = graph.nodes[0]["param_per_module"] + num_nodes = len(graph.nodes) + peak_mem = 0 + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] + AG_i = graph.nodes[i]["act_grad_per_module"] + TA_i = graph.nodes[i]["act_total"] + peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i) + compute_time = ( + graph.nodes[0]["fw_runtime_per_module"] + + graph.nodes[0]["bw_runtime_per_module"] + ) + return (peak_mem, compute_time) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf03f132b1a7fbd7c03ed0b1a0b03e40ebebdb2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mem_tracker.py @@ -0,0 +1,938 @@ +import math +import os +import re +import warnings +from collections.abc import Callable +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import ( + register_optimizer_step_post_hook, + register_optimizer_step_pre_hook, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) +_TOTAL_KEY = "Total" + +__all__ = ["MemTracker"] + + +class _RefType(str, Enum): + """Base Class for defining memory reference types, categorizing tensors based on their usage within a model.""" + + +class _State(str, Enum): + """Base Class for defining module state to capture snapshots .""" + + +class _MemRefType(_RefType): + """ + An enum to define memory reference types, categorizing tensors based on their usage within a model. + + - PARAM: Tensors registered as nn.Parameter within modules. + - BUFFER: Tensors registered as nn.Buffer within modules. + - GRAD: Gradients associated with parameters. + - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing. + - TMP: Temporary memory used during the backward pass, including gradients of activations. + - OPT: Tensors holding optimizer states. + - OTH: Tensors registered via `track_external` that do not fit the above categories. + """ + + PARAM = "Parameter" + BUFFER = "Buffer" + GRAD = "Gradient" + ACT = "Activation" + TEMP = "Temp" + OPT = "Optstate" + OTH = "Other" + + +class _ModState(_State): + """ + An enum to define the state of a module. + + - PRE_FW: The module is about to run the forward pass. + - POST_FW: The module has finished running the forward pass. + - PEAK_FW: The module has reached the peak memory usage during the forward pass. + - PRE_BW: The module is about to run the backward pass. + - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing. + - POST_FW_AC: The module has finished running the forward pass with activation checkpointing. + - POST_BW: The module has finished running the backward pass. + - PEAK_BW: The module has reached the peak memory usage during the backward pass. + """ + + PRE_FW = "Pre-Forward" + POST_FW = "Post-Forward" + PEAK_FW = "Peak-Forward" + PRE_BW = "Pre-Backward" + PRE_FW_AC = "Pre-Forward-AC" + POST_FW_AC = "Post-Forward-AC" + POST_BW = "Post-Backward" + PEAK_BW = "Peak-Backward" + + +class _ModMemStats: + """ + A class to store the memory statistics of a module. + + Args: + mod_fqn (str): The fully qualified name of the module. + Attributes: + mod_fqn (str): The fully qualified name of the module. + parameter_mem (int): The memory usage of the parameters of the module. + buffer_mem (int): The memory usage of the buffers of the module. + input_mem (int): The memory usage of the inputs to the module. + output_mem (int): The memory usage of the outputs from the module. + snapshots (Dict[_ModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states defined by ``_ModState``. + Note: + The memory snapshot is stored as a dictionary - Dict[torch.device, Dict[str, int]], where each key is a device, + and each value is another dictionary with keys as memory reference types defined by `_MemRefType` and + values as the memory consumed in bytes. + """ + + def __init__(self, mod_fqn: str): + self.mod_fqn = mod_fqn + self.parameter_mem: int + self.buffer_mem: int + self.input_mem: int + self.output_mem: int + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {} + + +class _WeakRefInfo: + """ + Manages memory statistics and device attributes for tensor storages. + """ + + def __init__( + self, size: int, element_size: int, device: torch.device, reftype: _RefType + ) -> None: + """ + Initializes the ``_WeakRefInfo`` object with tensor storage properties. + + Args: + size (int): The number of elements in the tensor storage. + element_size (int): The size of each element in the tensor storage. + device (torch.device): The device on which the tensor is allocated. + reftype (_RefType): The reference type of the tensor. + """ + self.size = size + self.element_size = element_size + self.reftype = reftype + # pyrefly: ignore [read-only] + self.device = device + self.mem_consumed = self._calculate_mem_consumed() + + def _calculate_mem_consumed(self) -> int: + """ + Calculates the memory consumed by the tensor storage, considering device-specific allocation rules. + + Returns: + int: The memory consumed in bytes. + """ + mem = self.size * self.element_size + if self.device.type == "cuda": + return math.ceil((mem) / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem + + def update_mem_consumed(self, st: torch.UntypedStorage) -> int: + """ + Updates and returns the memory consumed if the storage size has changed. + + Args: + st (torch.UntypedStorage): The tensor storage to check for size updates. + + Returns: + int: The updated memory consumed in bytes. + """ + if st.size() != self.size: + self.size = st.size() + self.mem_consumed = self._calculate_mem_consumed() + return self.mem_consumed + + @classmethod + def create_winfo( + cls, + st: torch.UntypedStorage, + device: torch.device, + reftype: _RefType, + callback: Callable[[Self, weakref.ref], Any] | None = None, + ) -> tuple[Self, weakref.ref]: + """ + Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object, + optionally attaching a callback to the weak reference. + + Args: + st (torch.UntypedStorage): The storage object for which to create the weak reference info. + device (torch.device): The device associated with the storage object. + reftype (_RefType): The type of reference, used to categorize the storage. + callback (Optional[Callable[[Self, weakref.ref]]]): A callback function that is called when + the storage object is about to be finalized (garbage collected). The callback function + should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage. + Returns: + Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the + weak reference to the storage object. The weak reference may have an attached callback if provided. + """ + + winfo = cls(st.size(), st.element_size(), device, reftype) + w_st = weakref.ref(st, partial(callback, winfo) if callback else None) + return winfo, w_st + + +def _get_mem_divisor(units: str) -> int: + unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30} + if units in unit_dict: + return unit_dict[units] + else: + raise ValueError( + f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}" + ) + + +def _rounding_fn(value: int, divisor: int, precision: int) -> float | int: + return value if divisor == 1 else round(value / divisor, precision) + + +def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + divisor = _get_mem_divisor(units) + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + print( + f"Device: {dev}", + *( + f"\t{k.value}: {_rounding_fn(v, divisor, 2)} {units}" + if isinstance(k, _RefType) + else f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}" + for k, v in dev_snap.items() + ), + sep="\n", + ) + + +def _print_snapshot_tabular( + snapshot: dict[torch.device, dict[str, int]], units: str +) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + divisor = _get_mem_divisor(units) + table_data = [] + key_list = list(next(iter(snapshot.values())).keys()) + headers = ["Device"] + [ + f"{key.value}" if isinstance(key, _RefType) else f"{key}" for key in key_list + ] + + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = [str(dev)] + row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values()) + table_data.append(row) + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +def _print_state_snapshots( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + for state, snapshot_list in snapshots.items(): + print(f"{state.value}") + for i, snapshot in enumerate(snapshot_list): + print(f"# {i + 1}:") + _print_snapshot(snapshot, units) + print() + + +def _print_state_snapshots_tabular( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + + table_data = [] + last_state_call = None + divisor = _get_mem_divisor(units) + for state, snapshot_list in snapshots.items(): + for i, snapshot in enumerate(snapshot_list): + state_call = f"{state.value} # {i + 1}" + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = { + "State & Call": ( + state_call if state_call != last_state_call else "" + ), + "Device": str(dev), + } + last_state_call = state_call + for k, v in dev_snap.items(): + row[f"{k.value}" if isinstance(k, _RefType) else f"{k}"] = ( + f"{_rounding_fn(v, divisor, 2)} {units}" + ) + table_data.append(row) + print(tabulate(table_data, headers="keys", tablefmt="rst")) + + +class _UpdateType(Enum): + # These are used for tracking updates to the continuouly maintained memory snapshot. + # ADD - When a new tensor storage is tracked + # DEL - When a tensor storage is about to be finalized (garbage collected). + # REF - When a tensor reference is updated, for instance, the gradients are marked as + # generic backward reference types until the grad_hook categorizes them as gradients. + # SIZE - When a tensor's storage is resized. + ADD = auto() + DEL = auto() + REF = auto() + SIZE = auto() + + +class MemTracker(TorchDispatchMode): + """ + A TorchDispatchMode to track, categorize and attribute the tensor memory created or accessed within its context. + + It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory and optimizer states + as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules, called within its context, + at various states defined by ``_ModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key + is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory + statistics of the module. + + Note: + The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created within + the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules, optimizers etc. + that need to be tracked but are created outside the MemTracker should be registered using the `track_external` method. + The `track_external` method should be called before the MemTracker is used. Any tensors created outside the ``MemTracker`` + and not supplied to the `track_external` method will not be tracked by the ``MemTracker``. + + Example usage: + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + mem_tracker = MemTracker() + mem_tracker.track_external(module, optimizer, inp) + with mem_tracker as mt: + loss = module(inp) + print("After Forward:") + mt.display_snapshot("current") + loss.backward() + optimizer.step() + optimizer.zero_grad() + mt.display_snapshot("peak") + mt.display_modulewise_snapshots(depth=3, units="MiB") + + Known Limitations: + - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. + - Resizing tensor storages directly by using non-Tensor methods other than using ``torch.Untyped_Storage.resize_`` + is not tracked. File a Github issue if you have use-cases for this. + - If the tensors are not traceable or wrappable subclasses of ``torch.Tensor``, then the tracker does not know how to + track their storages. File a Github issue if you have use-cases for this. + - During AC in the backward pass there might be misattribution between activation and temp memory, but the peak memory + will be tracked accurately. This will be fixed in the next update by hooking intricately with ``torch.uitls.checkpoint``. + """ + + def __init__(self) -> None: + self.memory_tracking = WeakIdKeyDictionary() + self._curr_mem_snap: dict[torch.device, dict[str, int]] = {} + self._peak_mem: dict[torch.device, int] = {} + self._peak_mem_snap: dict[torch.device, dict[str, int]] = {} + self._param_to_grad_hook_handles = WeakIdKeyDictionary() + self._optimizer_hook_handles: tuple[RemovableHandle, RemovableHandle] | None = ( + None + ) + # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage. + self._WINFO = WeakIdKeyDictionary() + self._mod_tracker = ModTracker() + # This is a general memory tracker which can be used with any ``_RefType`` subclass + self._ref_class: type[_RefType] = _MemRefType + # Flags to track if we are in the AC region or optimizer step region + self._in_opt: bool = False + self._in_ac: bool = False + # Weak references to the topmost AC module currently active + self._ac_mod: weakref.ref | None = None + self._orig_resize = torch.UntypedStorage.resize_ + self._depth = 0 + + def _update_snap( + self, + u_type: _UpdateType, + winfo: _WeakRefInfo, + old_mem_consumed: int | None = None, + old_reftype: _RefType | None = None, + ) -> None: + # Initialize a flag to track if the total memory might drop to zero after updates. + maybe_zero = False + # Ensure the device entry exists in the current memory snapshot, initializing if necessary. + # pyrefly: ignore [no-matching-overload] + dev_snap = self._curr_mem_snap.setdefault( + winfo.device, dict.fromkeys(self._ref_class, 0) + ) + dev_snap.setdefault(_TOTAL_KEY, 0) + # Handle different types of updates based on the update type (`u_type`). + if u_type == _UpdateType.ADD: + # Increase the memory consumed for the specific reference type and update the total. + dev_snap[winfo.reftype] += winfo.mem_consumed + dev_snap[_TOTAL_KEY] += winfo.mem_consumed + elif u_type == _UpdateType.DEL: + # Decrease the memory consumed for the specific reference type and reduce the total. + dev_snap[winfo.reftype] -= winfo.mem_consumed + dev_snap[_TOTAL_KEY] -= winfo.mem_consumed + maybe_zero = True + elif u_type == _UpdateType.REF: + assert old_reftype is not None + # Adjust memory consumption between two reference types within the same device. + dev_snap[old_reftype] -= winfo.mem_consumed + dev_snap[winfo.reftype] += winfo.mem_consumed + elif u_type == _UpdateType.SIZE: + assert old_mem_consumed is not None + # Adjust the memory consumed for a reference type due to a change in size. + change = winfo.mem_consumed - old_mem_consumed + dev_snap[winfo.reftype] += change + dev_snap[_TOTAL_KEY] += change + maybe_zero = True + else: + raise ValueError(f"Invalid update type: {u_type}") + # Check if the total memory for the device has dropped to zero. + if maybe_zero: + if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0: + # Remove the device entry from the memory snapshot if the total memory is zero. + del self._curr_mem_snap[winfo.device] + + def _update_and_maybe_create_winfos( + self, + t: torch.Tensor, + reftype: _RefType, + update_existing: bool = False, + ) -> set[_WeakRefInfo]: + sts = get_untyped_storages(t) + winfos = set() + for st in sts: + # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary. + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated. + old_reftype = winfo.reftype + if old_reftype != reftype: + # Update the reference type and apply changes via ``_update_snap``. + winfo.reftype = reftype + self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype) + winfos.add(winfo) + elif update_existing: + # If no existing ``_WeakRefInfo`` is found and update_existing is True, raise an error. + raise KeyError("No existing winfo found") + else: + # If no existing _WeakRefInfo is found and update_existing is False, create a new ``_WeakRefInfo``. + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary. + self._WINFO[st] = (winfo, w_st) + # Update the snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + winfos.add(winfo) + return winfos + + def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None: + # Callback to be called when the storage object corresponding to the ``_WeakRefInfo`` + # instance is about to be finalized. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.DEL, winfo) + + def _track_resize(self) -> None: + # Need to monkey-patch this because ``torch.UntypedStorage.resize_`` is not captured + # by ``TorchDispatchMode``. + @wraps(self._orig_resize) + def resize_(st: torch.UntypedStorage, size: int) -> None: + self._orig_resize(st, size) + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None and winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + + torch.UntypedStorage.resize_ = resize_ # type: ignore[method-assign, assignment] + + def _restore_resize(self) -> None: + torch.UntypedStorage.resize_ = self._orig_resize # type: ignore[method-assign] + + def _update_peak_stats(self, peak_state: _State) -> None: + # We first capture the current memory snapshot of the current tracker state then, + # We step through each of the modules we have tracked so far in ``memory_tracking`` + # and check if it is currently active by querying ``_mod_tracker.parents`` + # If it is active, we update the per device peak memory usage for the module + # corresponding to the ``_State`` which can be ``PEAK_FW`` or ``PEAK_BW``. + curr_snap = self._curr_mem_snap + + for mod_stats in self.memory_tracking.values(): + if mod_stats.mod_fqn in self._mod_tracker.parents: + if peak_state in mod_stats.snapshots: + for dev, dev_snap in curr_snap.items(): + if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]: + mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY] + mod_stats.snapshots[peak_state][-1][dev] = deepcopy( + dev_snap + ) + + for dev, dev_snap in curr_snap.items(): + if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]: + self._peak_mem[dev] = dev_snap[_TOTAL_KEY] + self._peak_mem_snap[dev] = deepcopy(dev_snap) + + def _track(self, reftype: _RefType, t: torch.Tensor) -> None: + # Get the storages of the tensor and check if we have already tracked them. + # If yes, then check if the storage size has changed and update the current snapshot. + # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary. + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + if winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + return + else: + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + self._WINFO[st] = (winfo, w_st) + # Update the current snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + + def get_tracker_snapshot( + self, type: str = "current" + ) -> dict[torch.device, dict[str, int]]: + """ + Capture a snapshot of the memory usage breakdown per device, based on the specified type. + + Args: + type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + Returns: + Dict[torch.device, Dict[str, int]]: A dictionary where each key is a torch.device, and each value is another + dictionary. This inner dictionary has keys representing memory reference + types as defined in ``_MemRefType`` and values representing the amount of + memory consumed in bytes. + Raises: + ValueError: If an invalid type is specified. + """ + if type == "current": + return deepcopy(self._curr_mem_snap) + elif type == "peak": + return deepcopy(self._peak_mem_snap) + else: + raise ValueError(f"Invalid type {type}") + + def _track_module_params_and_buffers( + self, module: nn.Module, install_grad_hooks: bool = True + ) -> tuple[int, int]: + # Track the parameters and buffers of the module if not already tracked. + # If the parameters have gradients, track the gradients as well. + # If install_grad_hooks is True, install a gradient hook on the parameters + # to track the gradients, if it has not already been installed. + # Return the total memory consumed by the parameters and buffers. + def _grad_hook(grad: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + grad, + _MemRefType.GRAD, + ) + + param_memory = 0 + for param in module.parameters(): + winfos = self._update_and_maybe_create_winfos( + param, + _MemRefType.PARAM, + ) + param_memory += sum(winfo.mem_consumed for winfo in winfos) + if param.grad is not None: + self._update_and_maybe_create_winfos( + param.grad, + _MemRefType.GRAD, + ) + if ( + self._param_to_grad_hook_handles.get(param, None) is None + and install_grad_hooks + ): + grad_hook_handle = param.register_hook(_grad_hook) + post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook( + lambda p: (_grad_hook(p.grad)) + ) + self._param_to_grad_hook_handles[param] = ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) + buffer_memory = 0 + for buffer in module.buffers(): + winfos = self._update_and_maybe_create_winfos( + buffer, + _MemRefType.BUFFER, + ) + buffer_memory += sum(winfo.mem_consumed for winfo in winfos) + return (param_memory, buffer_memory) + + def _track_inputs_or_outputs(self, args: Any) -> int: + # Calculate the memory consumed by the inputs or outputs of the module. + input_or_output_memory = 0 + + def add_inps_or_outs(t: torch.Tensor) -> None: + nonlocal input_or_output_memory + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + input_or_output_memory += winfo.mem_consumed + + tree_map_only(torch.Tensor, add_inps_or_outs, args) + return input_or_output_memory + + def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: + # This is installed as a pre-fwd user hook with ``ModTracker.`` Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers, + # input and output memory of the module. Create a new ``_ModMemStats`` instance for the module + # and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + mod_name = self._mod_tracker.get_known_fqn(module) + assert mod_name is not None + if module not in self.memory_tracking: + mod_stats = _ModMemStats(mod_name) + param_mem, buffer_mem = self._track_module_params_and_buffers( + module, install_grad_hooks=True + ) + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.parameter_mem = param_mem + mod_stats.buffer_mem = buffer_mem + mod_stats.input_mem = input_mem + self.memory_tracking[module] = mod_stats + state = _ModState.PRE_FW + + elif self._mod_tracker.is_bw: + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(module) + self._in_ac = True + else: + parents = set(self._mod_tracker.parents) - {mod_name} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "MemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.mod_fqn = mod_name + mod_stats.input_mem = input_mem + + mem_snapshot = self.get_tracker_snapshot() + if state == _ModState.PRE_FW: + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot)) + + def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None: + # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward so we calculate the output memory + # of the module and update its mod_stats. + mod_stats = self.memory_tracking[module] + if self._mod_tracker.is_bw: + state = _ModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is module: + self._ac_mod = None + self._in_ac = False + else: + state = _ModState.POST_FW + output_mem = self._track_inputs_or_outputs(outputs) + mod_stats.output_mem = output_mem + mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + + def _pre_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it. + # If the module is None, we skip the hook. + # This can happen since this installed inside a multi-grad hook on the module's output tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mem_snapshot = self.get_tracker_snapshot() + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append( + deepcopy(mem_snapshot) + ) + + def _post_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module if it is not None. + # This can happen since this installed inside a multi-grad hook on the module's input tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( + self.get_tracker_snapshot() + ) + + def _track_optimizer_states( + self, reftype: _RefType, optimizer: optim.Optimizer + ) -> None: + for states in optimizer.state.values(): + for val in states.values(): + if isinstance(val, torch.Tensor): + self._update_and_maybe_create_winfos( + val, + reftype, + ) + + def _register_global_optimizer_hook(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_MemRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + register_optimizer_step_pre_hook(_opt_step_pre_hook), + register_optimizer_step_post_hook(_opt_step_post_hook), + ) + + def _deregister_param_and_optimizer_hooks(self) -> None: + for ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) in self._param_to_grad_hook_handles.values(): + grad_hook_handle.remove() + post_acc_grad_hook_handle.remove() + self._param_to_grad_hook_handles.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_external( + self, *external: nn.Module | optim.Optimizer | torch.Tensor + ) -> None: + """ + Track tensors and stateful objects like modules, optimizers etc. that are created outside the MemTracker. + + This method should be called before the ``MemTracker`` is used. Any tensors that are not module parameters, buffers, + gradients activations, or optimizer states will be categorized as ``Other``. If you want them categorized with a + custom name, please file a GitHub issue. Any tensors created outside the MemTracker and not supplied to this + method will not be be tracked by ``MemTracker``. + + Args: + *external (Union[nn.Module, optim.Optimizer, torch.Tensor]): The external modules, optimizers, and + tensors to be tracked. + """ + flat_external, _ = tree_flatten(external) + for obj in flat_external: + if isinstance(obj, torch.Tensor): + self._update_and_maybe_create_winfos( + obj, + _MemRefType.OTH, + ) + elif isinstance(obj, torch.nn.Module): + self._track_module_params_and_buffers(obj, install_grad_hooks=False) + elif isinstance(obj, optim.Optimizer): + self._track_optimizer_states(_MemRefType.OPT, obj) + elif obj is None: + continue + else: + raise TypeError( + f"Object of type {type(obj)} is not supported for tracking. " + f"Only stateful objects like modules, optimizers, and tensors are supported." + ) + + def display_snapshot( + self, type: str = "current", units: str = "B", tabulate: bool = False + ) -> None: + """ + Display the memory usage breakdown snapshot of the tracker based on the specified type and units. + + Keyword args: + type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False. + """ + snapshot = self.get_tracker_snapshot(type) + if tabulate: + _print_snapshot_tabular(snapshot, units) + else: + _print_snapshot(snapshot, units) + + def display_modulewise_snapshots( + self, depth: int = 2, units: str = "B", tabulate: bool = False + ) -> None: + """ + Print per device memory breakdown snapshot for each module called within MemTracker. + + Snapshots are displayed for the states defined by ``_ModState``. + The module hierarchy is displayed up to the specified depth. + + Keyword Args: + depth (int, optional): The depth of the module hierarchy to display. Defaults to 2. + units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. + """ + + def natural_sort_key(s: str) -> list[int | str]: + return [ + int(text) if text.isdigit() else text.lower() + for text in re.split("([0-9]+)", s) + ] + + for mod_stats in sorted( + self.memory_tracking.values(), + key=lambda m_stats: natural_sort_key(m_stats.mod_fqn), + ): + mod_fqn = mod_stats.mod_fqn + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + if tabulate: + _print_state_snapshots_tabular(mod_stats.snapshots, units) + else: + _print_state_snapshots(mod_stats.snapshots, units) + + def reset_mod_stats(self) -> None: + """ + Reset all the module memory stats. Clears ``memory_tracking`` dictionary. + """ + self.memory_tracking.clear() + + def __enter__(self) -> "MemTracker": + if self._depth == 0: + self._register_global_optimizer_hook() + self._mod_tracker.register_user_hooks( + self._pre_fw_hook, + self._post_fw_hook, + self._pre_bw_hook, + self._post_bw_hook, + ) + self._track_resize() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + super().__enter__() + self._depth += 1 + return self + + # pyrefly: ignore [bad-override] + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_param_and_optimizer_hooks() + self._mod_tracker.clear_user_hooks() + self._restore_resize() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into local tensor ops, before `MemTracker` sees them. + if any(t == DTensor for t in types): + return NotImplemented + if ( + func is torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + # pyrefly: ignore [index-error] + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _MemRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _MemRefType.TEMP + else: + reftype = _MemRefType.ACT + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW + self._update_peak_stats(peak_state) + return res diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..890d2be2794a4e570085a91da1440842473c9f49 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/memory_tracker.py @@ -0,0 +1,304 @@ +# mypy: allow-untyped-defs +import operator +import pickle +from collections import defaultdict +from collections.abc import Callable, Sequence +from itertools import chain +from typing import Any, no_type_check, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch.utils._python_dispatch import TorchDispatchMode + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +BYTES_PER_MB = 1024 * 1024.0 + + +class MemoryProfileDispatchMode(TorchDispatchMode): + """Run in ``TorchDispatchMode`` to get memory stats at operator level.""" + + def __init__(self, memory_tracker) -> None: + self.memory_tracker = memory_tracker + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + rs = func(*args, **kwargs) + if func is torch.ops.aten.detach.default: + return rs + func_name: str = ( + self.memory_tracker._cur_module_name + + "." + + func.__name__ + + "_" + + str(self.memory_tracker._operator_names[func.__name__]) + ) + self.memory_tracker._operator_names[func.__name__] = ( + self.memory_tracker._operator_names[func.__name__] + 1 + ) + self.memory_tracker._record_memory_stats(func_name) + + return rs + + +class MemoryTracker: + """ + Collect and plot the memory stats at operator level. + + Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``. + It also prints a summary for the top 20 operators that generate the most memories. + + Example usage: + + >>> # xdoctest: +SKIP(failing) + >>> net.cuda() + >>> input = input.cuda() + + >>> mem_tracker = MemoryTracker() + >>> mem_tracker.start_monitor(net) + + >>> net.zero_grad(True) + >>> loss = net(input) + >>> if isinstance(loss, dict): + >>> loss = loss['out'] + >>> loss.sum().backward() + >>> net.zero_grad(set_to_none=True) + + >>> mem_tracker.stop() + >>> mem_tracker.summary() + >>> mem_tracker.show_traces() + """ + + def __init__(self) -> None: + torch._C._log_api_usage_once("torch.distributed.memory_tracker") + self._hooks: list[RemovableHandle] = [] + self._operator_names: dict[str, int] = defaultdict(int) + self.memories_allocated: dict[int, dict[str, float]] = defaultdict() + self.memories_active: dict[int, dict[str, float]] = defaultdict() + self.memories_reserved: dict[int, dict[str, float]] = defaultdict() + self._markers: dict[str, int] = defaultdict(int) + self._cur_module_name: str = "" + self._op_index: int = 0 + self._num_alloc_retries: int = 0 + self._device_module = torch.get_device_module() + + @no_type_check + def start_monitor(self, root_module: nn.Module) -> None: + """ + Register module hooks and entering ``MemoryProfileDispatchMode``. + + This enables operator level memory stats can be tracked during module runtime. + """ + self._clear_state() + root_module.__setattr__("_memory_tracker_is_root", True) + for name, m in root_module.named_modules(): + if m is not root_module: + m.__setattr__("_memory_tracker_is_root", False) + # fused_proxy_group does not support hooks + if ".fused_proxy_grouped_embedding_bag" in name: + continue + # hook ordering with other hooks added by users is not managed, so + # the memory stats tracked here may not completely accurate. + h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name)) + h2 = m.register_forward_hook(self._create_post_forward_hook(name)) + # it does not work well with jagged tensor somehow, the root cause is not + # clear and remove it for now as it does not really capture important info. + # h3 = m.register_backward_hook(self._create_backward_hook(name)) + self._hooks.extend([h1, h2]) + self._device_module.empty_cache() + assert getattr(self, "profile_mode", None) is None + self.profile_mode = MemoryProfileDispatchMode(self) + self.profile_mode.__enter__() + + @no_type_check + def stop(self) -> None: + """ + Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level. + + Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``. + """ + self._num_alloc_retries = self._device_module.memory_stats().get( + "num_alloc_retries", 0 + ) + + for h in self._hooks: + h.remove() + self._hooks.clear() + assert getattr(self, "profile_mode", None) is not None + self.profile_mode.__exit__(None, None, None) + self.profile_mode = None + + @no_type_check + def summary(self, top: int = 20) -> None: + """ + Print out the top operators that generate the most memories. + + The number of the top operators can be configured. + """ + op_diff: dict[str, float] = defaultdict(float) + op_name, previous_allocated_memory = self.memories_allocated[0] + for i in range(1, self._op_index): + op_name, current_allocated_memory = self.memories_allocated[i] + op_diff[op_name] = current_allocated_memory - previous_allocated_memory + previous_allocated_memory = current_allocated_memory + + print("------------------------------------------------") + print(f"The number of alloc retries are: {self._num_alloc_retries}") + print(f"Top {top} ops that generates memory are:") + for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ + :top + ]: + print(f"{k}: {v}MB") + print("------------------------------------------------") + + @no_type_check + def show_traces(self, path: str = "") -> None: + import matplotlib.pyplot as plt + + def _plot_figure(x, y_values, labels): + min_val = min(chain.from_iterable(y_values)) * 0.999 + max_val = max(chain.from_iterable(y_values)) * 1.001 + plt.figure() + for y, label in zip(y_values, labels): + plt.plot(x, y, label=label) + plt.xlabel("# Operator Calls") + plt.ylabel("Memory (MB)") + plt.legend() + for marker_name, marker in self._markers.items(): + if marker_name == "fw_bw_boundary": + plt.plot( + [marker, marker], + [min_val, max_val], + "r", + lw=2, + label=marker_name, + ) + else: + plt.plot( + [marker, marker], + [min_val, max_val], + "k-", + lw=2, + label=marker_name, + ) + + if path != "": + self.load(path) + + y_1 = [gb for (name, gb) in self.memories_allocated.values()] + y_2 = [gb for (name, gb) in self.memories_active.values()] + y_3 = [gb for (name, gb) in self.memories_reserved.values()] + x = list(range(len(y_1))) + # Split figures when there is big difference between + # "reserved_memory" and "allocated_memory" or "active_memory". + _plot_figure( + x, + [list(y_1), list(y_2), list(y_3)], + ["allocated_memory", "active_memory", "reserved_memory"], + ) + _plot_figure(x, [list(y_1)], ["allocated_memory"]) + _plot_figure(x, [list(y_2)], ["active_memory"]) + _plot_figure(x, [list(y_3)], ["reserved_memory"]) + + def save_stats(self, path: str) -> None: + """Save the stats using pickle during runtime if users want to plot the traces in other places like notebook.""" + stats = { + "memories_allocated": self.memories_allocated, + "memories_active": self.memories_active, + "memories_reserved": self.memories_reserved, + "markers": self._markers, + "num_alloc_retries": self._num_alloc_retries, + } + + with open(path, "wb") as f: + pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL) + + def load(self, path: str) -> None: + """Load the pickled memory stats to plot the traces or print the summary.""" + with open(path, "rb") as f: + stats = pickle.load(f) + + self.memories_allocated = stats["memories_allocated"] + self.memories_active = stats["memories_active"] + self.memories_reserved = stats["memories_reserved"] + self._markers = stats["markers"] + self._num_alloc_retries = stats["num_alloc_retries"] + + def _create_pre_forward_hook(self, name: str) -> Callable: + """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: + self._cur_module_name = f"{name}.forward" + if ( + # pyrefly: ignore [invalid-argument] + hasattr(module, "_memory_tracker_is_root") + # pyrefly: ignore [not-callable] + and module._memory_tracker_is_root + ): + self._add_marker("fw_start") + + return _pre_forward_hook + + def _create_post_forward_hook(self, name: str) -> Callable: + """Insert the marker 'fw_bw_boundary' at the boundary of forward and backward pass.""" + + def _post_forward_hook( + module: nn.Module, + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + ) -> None: + if ( + # pyrefly: ignore [invalid-argument] + hasattr(module, "_memory_tracker_is_root") + # pyrefly: ignore [not-callable] + and module._memory_tracker_is_root + ): + self._add_marker("fw_bw_boundary") + + return _post_forward_hook + + def _create_backward_hook(self, name: str) -> Callable: + """Insert the current module name with backward prefix for the operator name.""" + + def _backward_hook( + module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor + ) -> None: + self._cur_module_name = f"{name}.backward" + + return _backward_hook + + @no_type_check + def _record_memory_stats(self, fn_name: str) -> None: + """ + Record current memory allocated, current memory active and current memory reserved. + + The memory stats dict is indexed with ``self._op_index``. + """ + memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB + memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB + memory_active: float = ( + self._device_module.memory_stats().get("active_bytes.all.current", 0) + / BYTES_PER_MB + ) + self.memories_allocated[self._op_index] = (fn_name, memory_allocated) + self.memories_reserved[self._op_index] = (fn_name, memory_reserved) + self.memories_active[self._op_index] = (fn_name, memory_active) + self._op_index += 1 + + def _add_marker(self, marker_name: str) -> None: + """Set the marker's x-axis value.""" + marker_val = len(self.memories_allocated.values()) + self._markers[marker_name] = marker_val + + def _clear_state(self) -> None: + """Clear states when start_monitor() is called.""" + self._operator_names.clear() + self.memories_allocated.clear() + self.memories_active.clear() + self.memories_reserved.clear() + self._markers.clear() + self._cur_module_name = "" + self._op_index = 0 + self._num_alloc_retries = 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..bae745bcc58040dd14a1dfbd0c2f116554870689 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/mod_tracker.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +import warnings +import weakref +from collections.abc import Callable + +import torch +from torch.autograd.graph import register_multi_grad_hook +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from torch.utils._pytree import tree_flatten + + +__all__ = ["ModTracker"] + + +class ModTracker: + """ + ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. + + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = torch.nn.Linear(2, 2) + + with ModTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return torch.mm(m1, m2.t()) + bias + + torch.nn.functional.linear = my_linear + + mod(torch.rand(2, 2)) + + """ + + parents: set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self): + self.parents = {"Global"} + self._active_module_cnt = {} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._post_bw_callbacks_to_enqueue: list[Callable] = [] + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue): + torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback) + self._post_bw_callbacks_to_enqueue.clear() + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + torch.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return torch._C._current_graph_task_id() != -1 + + def get_known_fqn(self, mod): + """ + Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. + """ + return self._known_modules.get(mod, None) + + def register_user_hooks( + self, + pre_fw_hook: Callable | None = None, + post_fw_hook: Callable | None = None, + pre_bw_hook: Callable | None = None, + post_bw_hook: Callable | None = None, + ): + """ + Registers user-specified hooks to be called before/after the forward/backward pass for each + module tracked by the ``ModTracker``. One or more can be ``None``. + Args: + pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the + module. It should have the following signature: + pre_fw_hook (module, input) -> None + post_fw_hook (Callable, optional): A hook to be called after the forward pass for the + module. It should have the following signature: + post_fw_hook (module, input, output) -> None + pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of + the module that require gradients. It should have the following signature: + pre_bw_hook (module, grad_output) -> None + post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of + the module that require gradients. It should have the following signature: + post_bw_hook (module, grad_input) -> None + Raises: + AssertionError: If a new hook is provided when one is already registered. + Note: + If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will + will receive None as the module argument. + The module fqn will be present in the ``parents`` attribute when each of the hooks is called. + Hooks are intended to be used as markers only not to modify the inputs/outputs. + """ + + def set_hook(hook, user_hook, hook_name): + if hook is not None and user_hook is not None: + raise AssertionError( + f"Only one {hook_name} can be registered at a time" + f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" + ) + return hook + + self._user_pre_fw_hook = set_hook( + pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" + ) + self._user_post_fw_hook = set_hook( + post_fw_hook, self._user_post_fw_hook, "post_fw_hook" + ) + self._user_pre_bw_hook = set_hook( + pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" + ) + self._user_post_bw_hook = set_hook( + post_bw_hook, self._user_post_bw_hook, "post_bw_hook" + ) + + def clear_user_hooks(self): + """ + Clears the user specified hooks registered with ``register_user_hooks`` + """ + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, w_mod, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents and not self.is_bw: + + def custom_formatwarning(msg, category, filename, lineno, line=None): + return f"{filename}:{lineno}: {category.__name__}: {msg} \n" + + # pyrefly: ignore [bad-assignment] + warnings.formatwarning = custom_formatwarning + warnings.warn( + "The module hierarchy tracking maybe be messed up." + " Please file a bug to PyTorch, if it is the case.", + stacklevel=2, + ) + if name not in self.parents: + self._active_module_cnt[name] = 1 + self.parents.add(name) + else: + self._active_module_cnt[name] += 1 + + if self._user_pre_bw_hook is not None and is_bw: + self._user_pre_bw_hook(w_mod(), args) + + return fn + + def _get_pop_fn(self, w_mod, name, is_bw): + def fn(*args): + if self._user_post_bw_hook is not None and is_bw: + self._user_post_bw_hook(w_mod(), args) + if name in self.parents: + self._active_module_cnt[name] -= 1 + if self._active_module_cnt[name] == 0: + self.parents.remove(name) + elif not self.is_bw: + # Due to some input/output not requiring gradients, we cannot enforce + # proper nesting in backward + raise RuntimeError( + "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + ) + + return fn + + def _fw_pre_hook(self, mod, input): + if torch._dynamo.eval_frame._is_in_optimized_module(): + return + + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + self._get_append_fn(w_mod, name, False)() + if self._user_pre_fw_hook is not None: + self._user_pre_fw_hook(mod, input) + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw: + if tensors: + register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) + else: + self._post_bw_callbacks_to_enqueue.append( + self._get_pop_fn(w_mod, name, True) + ) + + def _fw_post_hook(self, mod, input, output): + if torch._dynamo.eval_frame._is_in_optimized_module(): + return + + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + if self._user_post_fw_hook is not None: + self._user_post_fw_hook(mod, input, output) + self._get_pop_fn(w_mod, name, False)() + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook( + tensors, self._get_append_fn(w_mod, name, True), mode="any" + ) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook( + self._fw_post_hook, always_call=True + ) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..caf399cf6a802677f754084d2d867a743036520f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/runtime_estimator.py @@ -0,0 +1,398 @@ +# Owner(s): ["module: unknown"] +from collections import defaultdict +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + _VIEW_OPS, + get_compute_time, + get_transfer_time, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it aggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _no_fallback_kernel: set[torch._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: dict[str, dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: list[str] = [] + self.mod_bw_pre_order: list[str] = [] + self.mod_fw_post_order: list[str] = [] + self.mod_bw_post_order: list[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in _FLOAT_TYPES: + out = torch.rand_like(e, device=e.fake_device) + else: + out = torch.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record(torch.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(torch.cuda.current_stream()) + torch.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, torch.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance(cls.fake_mode, FakeTensorMode), ( + "Initialize/Assign FakeTensorMode before using this function" + ) + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert torch.cuda.is_available(), ( + "Roofline estimation needs to access CUDA capabilities to make estimations" + ) + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "No FakeTensorMode found, designed to used under FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + # pyrefly: ignore [bad-override] + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..c43de8c2b916742cf131b1761e801b41fc689ba6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_estimator.py @@ -0,0 +1,961 @@ +import math +import os +import sys +from collections import OrderedDict +from dataclasses import astuple, dataclass +from typing import Any, NamedTuple +from typing_extensions import Self + +import torch +from torch import nan, nn, UntypedStorage +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten +from torch.utils.checkpoint import SAC_IGNORED_OPS + + +__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"] +aten = torch.ops.aten + +_ADDITIONAL_IGNORED_OPS = { + aten.lift_fresh.default, # type: ignore[attr-defined] + torch.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined] + aten.clone.default, # type: ignore[attr-defined] # seems needed for torch.compile +} +OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + + +def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError("Please install tabulate.") from err + + # Use tabulate to print the table + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +# Based on: +# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71 +@dataclass +class _SACMetadata: + """ + Stores metadata for a single operator for SAC. + + Attributes: + func (Any): The operator function. + time_taken (float): The time taken by the operator. + memory_used (float): The memory used by the operator. + curr_idx (int): The current operator index. + output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs. + inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator. + is_view_like (bool): Whether the operator is view-like. + is_rand_op (bool): Whether the operator is a random operator. + """ + + func: Any + time_taken: float + memory_used: float + curr_idx: int + output_ids: tuple[int, ...] + inplace_info: tuple[int, ...] + is_view_like: bool + is_rand_op: bool + + +@dataclass +class _SACModMetadata: + """ + Stores metadata for a module for SAC. + + Attributes: + start_idx (int): The starting index of the module's operators. + force_store_random (bool): Whether to force store random operators in the module. + sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module. + """ + + start_idx: int + force_store_random: bool + sac_metadata: list[_SACMetadata] + + +@dataclass +class SACStats: + """ + A class for storing Activation Checkpointing statistics corresponding to a module. + + Attributes: + func_names (List[str]): List of operator names. + runtimes (List[float]): List of operator runtimes in millliseconds. + memory (List[int]): List of operator memory usage in bytes. + view_like_ops (List[int]): Indices of view-like operators. + rand_ops (List[int]): Indices of random operators. + saved_autograd_ops (List[int]): Indices of operator results saved by autograd engine. + inplace_ops (List[Tuple[int, int]]): Tuple of indices of op and its first parent for Inplace operators. + force_store_random (bool): Whether to force store random operator results. + """ + + func_names: list[str] + runtimes: list[float] + memory: list[int] + view_like_ops: list[int] + rand_ops: list[int] + saved_autograd_ops: list[int] + inplace_ops: list[tuple[int, int]] + force_store_random: bool + + +class MSPS(NamedTuple): + """ + Represents Memory and Runtime Statistics for an operator/operator group. + + Attributes: + func_names (set[str]): Set of operator/operator group names. + op_idx (int): Operator index (group head index in case of operator groups). + memory (int): Memory usage in bytes. + runtime (float): Runtime in milliseconds. + msps (float): Memory per second calculated as memory/runtime. + """ + + func_names: set[str] + op_idx: int + memory: int + runtime: float + msps: float + + +@dataclass +class SACTradeOffStats: + """ + Stores statistics for activation-checkpointing trade-off. + + Attributes: + n_segments (int): Number of piecewise linear segments fitted to the trade-off curve. + slopes (List[float]): Slopes of the pieces of linear segments fitted to the trade-off curve. + intercepts (List[float]): Intercepts of the of the pieces of linear segments fitted to the trade-off curve. + fit_breaks (List[float]): Breakpoints of the of the pieces of linear segments fitted to the trade-off curve. + tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time. + sac_memory (int): Total memory of operations available for activation checkpointing in bytes. + sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds. + """ + + n_segments: int + slopes: list[float] + intercepts: list[float] + fit_breaks: list[float] + tradeoff_curve: OrderedDict[float, float] + sac_memory: int + sac_runtime: float + + +@dataclass +class SACGreedyOrderMeta: + """ + Stores metadata for Greedy-order SAC. + + Attributes: + recomputed_ops (set[int]): Set of operator indices to be recomputed. + stored_ops (set[int]): Set of operator indices to be stored. + inplace_op_groups (dict[int, set[int]]): Dictionary of inplace operator groups from group-head to operators. + random_ops_group (dict[int, set[int]]): Dictionary of random op group head to random ops. + msps_meta (list[MSPS]): List of Memory and Runtime Statistics for operators. + """ + + recomputed_ops: set[int] + stored_ops: set[int] + inplace_op_groups: dict[int, set[int]] + random_ops_group: dict[int, set[int]] + msps_meta: list[MSPS] + + +class SACEstimator(TorchDispatchMode): + """ + Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC). + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and + runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides + detailed statistics and metadata information for operators of each module and provides a greedy order for selecting + the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory + vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two + estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). + + Attributes: + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``. + sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. + sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. + + Note: + 1) This class is designed to be used under ``FakeTensorMode``. + 2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication. + + Example usage: + + .. code-block:: python + + sac_estimator = SACEstimator() + with FakeTensorMode(): + module = ... + inp = ... + with sac_estimator("operator-level-cost-model"): + output = module(inp) + sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) + """ + + def __init__(self) -> None: + self.sac_mod_stats: dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {} + self._mod_tracker = ModTracker() + self._sac_metadata: list[_SACMetadata] = [] + self._sac_mod_metadata: dict[str, _SACModMetadata] = {} + self._leaf_modules: set[str] = set() + self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks( + self._pack_hook, lambda x: x + ) + self._saved_tensor_ids: set[int] = set() + self._estimate_runtime = RuntimeEstimator._roofline_estimate + + def _pack_hook(self, x: torch.Tensor) -> torch.Tensor: + # Hook function to track underlying storage IDs of tensors + # Updates the _saved_tensor_ids set with the IDs of the tensor's storages + # Used in conjunction with torch.autograd.graph.saved_tensors_hooks + untyped_storages = get_untyped_storages(x) + storage_ids = (hash(st) for st in untyped_storages) + self._saved_tensor_ids.update(storage_ids) + return x + + def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None: + # Pre-forward hook function to prepare module metadata + # Tracks module FQN, force store random flag, and ``SACModMetadata`` + # Initializes metadata for non-leaf modules, marks leaf modules + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + num_children = sum(1 for _ in mod.children()) + if num_children > 0: + force_store_random = self._get_force_store_random(inputs) + self._sac_mod_metadata[mod_fqn] = _SACModMetadata( + start_idx=len(self._sac_metadata), + force_store_random=force_store_random, + sac_metadata=[], + ) + else: + self._leaf_modules.add(mod_fqn) + + def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None: + # 1. Retrieves the module's FQN and checks if it's a leaf module + # 2. If not a leaf module, computes: + # - ``SACStats`` using the module's metadata and force store random flag + # - ``SACGreedyOrderMeta`` using the computed SAC statistics + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + if mod_fqn in self._leaf_modules: + return + else: + self.sac_mod_stats[mod_fqn] = self._get_sac_stats( + data=self._sac_mod_metadata[mod_fqn].sac_metadata, + force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random, + ) + self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta( + self.sac_mod_stats[mod_fqn] + ) + + def _get_force_store_random(self, inputs: Any) -> bool: + flat_inputs, _ = tree_flatten(inputs) + return all(not isinstance(x, torch.Tensor) for x in flat_inputs) + + def _get_sac_stats( + self, data: list[_SACMetadata], force_store_random: bool + ) -> SACStats: + # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP] + + ( + ops, + runtimes_, + memory_, + new_ids, + output_ids, + inplace_ops_, + view_like_ops_, + rand_ops_, + ) = zip(*[astuple(x) for x in filtered_data], strict=True) + + # 2. Extract the metadata information + runtimes = list(runtimes_) + memory = list(memory_) + func_names = [op._overloadpacket.__name__ for op in ops] + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + saved_autograd_ops = [ + i + for i, out_ids in enumerate(output_ids) + if set(out_ids).issubset(self._saved_tensor_ids) + ] + + # 3. Remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op + # to itself if the original parent is in OPS_TO_ALWAYS_SKIP. + try: + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + except ValueError as err: + raise ValueError( + f"The remapping of inplace ops failed since one of the inplace op parents" + f" must have been present in {OPS_TO_ALWAYS_SKIP}" + ) from err + + # 4. The last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set({x[0] for x in inplace_ops}) + reversed_skip_ops = sorted(skip_ops_, reverse=True) + for op in reversed_skip_ops: + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``. + return SACStats( + func_names=func_names, + runtimes=runtimes, + memory=memory, + view_like_ops=view_like_ops, + rand_ops=rand_ops, + saved_autograd_ops=saved_autograd_ops, + inplace_ops=inplace_ops, # type: ignore[arg-type] + force_store_random=force_store_random, + ) + + def _get_inplace_metadata( + self, func: Any, out_storages: set[UntypedStorage] + ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]: + # 1. Get the current index of the metadata obtained so far + curr_idx = len(self._sac_metadata) + # 2. Get the set of active modules that are not leaf + active_mod_fqns: set[str] = { + par for par in self._mod_tracker.parents if par not in self._leaf_modules + } + # 3. Output ids are the identifies of the storage objects corresponding to the tensors + output_ids = tuple(hash(st) for st in out_storages) + # 4. If the function is not inplace, return + if not is_inplace(func): + return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ()) + + op_idx = curr_idx + # 5. Initialize the parent op ids of the inplace op for each of the active modules + mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1) + for i, d in enumerate(self._sac_metadata): + # 6. Find the first occurrence of a tensor corresponding to each module that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + if set(output_ids).issubset(set(past_output_ids)): + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx == -1: + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + if i >= acm_stats.start_idx: + mod_op_parent_idxs[mod_fqn] = i + else: + assert mod_fqn == "Global" + mod_op_parent_idxs[mod_fqn] = i + # 7. If no parent tensor is found, then it's probably an inplace op on the arguments + # so one can just store the current-op idx as parent idx + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx < 0: + mod_op_parent_idxs[mod_fqn] = op_idx + mod_inplace_info = { + mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn]) + for mod_fqn in active_mod_fqns + } + return curr_idx, output_ids, mod_inplace_info # type: ignore[return-value] + + def __torch_dispatch__( # type: ignore[no-untyped-def] + self, func, types, args=..., kwargs=None + ): + # 1. Get the runtime estimate + out, op_time = self._estimate_runtime(func, args, kwargs) + flat_outs, _ = tree_flatten(out) + out_storages_cuda: set[UntypedStorage] = set() + out_storages_cpu: set[UntypedStorage] = set() + cuda_devices: set[torch.device] = set() + for o in flat_outs: + if isinstance(o, torch.Tensor): + if o.device.type == "cuda": + out_storages_cuda.update(get_untyped_storages(o)) + cuda_devices.add(o.device) + else: + out_storages_cpu.update(get_untyped_storages(o)) + + # Check if there's more than 1 CUDA device + assert len(cuda_devices) <= 1, ( + f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + ) + + # 2. Get the memory consumed by output + nbytes_cuda = sum( + math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + for st in out_storages_cuda + ) + nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu) + nbytes = nbytes_cuda + nbytes_cpu + # 3. Get the current operator index, output storage identifiers and inplace metadata + out_storages = out_storages_cuda | out_storages_cpu + curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata( + func, out_storages + ) + # 4. Determine if the function is in-place, random-op or a view-like + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = torch.Tag.nondeterministic_seeded in func.tags + if is_view_like: + nbytes = 0 + # sdpa has non-deterministic seed, but might be deterministic + # if no dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + # pyrefly: ignore [missing-attribute] + is_rand_op = kwargs.get("dropout_p", 0) != 0 + # 5. Create metadata information per active non-leaf module + for mod_fqn in self._mod_tracker.parents: + if mod_fqn in self._leaf_modules: + continue + acm = _SACMetadata( + func=func, + time_taken=op_time, + memory_used=nbytes, + curr_idx=curr_idx, + output_ids=output_ids, + inplace_info=mod_inplace_info[mod_fqn], + is_view_like=is_view_like, + is_rand_op=is_rand_op, + ) + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + acm_stats.sac_metadata.append(acm) + else: + assert mod_fqn == "Global", ( + f"Module {mod_fqn} not found in AC Mod Stats" + ) + self._sac_metadata.append(acm) + + return out + + def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: + # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage. + # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group + # The top-most op can itself be an inplace-op or can be a non-inplace op. + # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. + inplace_op_groups: dict[int, set[int]] = {} + inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops) + + # Initialize inplace_op_groups using inplace_op_to_group_head + for op_idx, group_head_idx in inplace_op_to_group_head.items(): + op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx}) + op_group.add(op_idx) + + # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved + # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, + # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op + # as the leader of the random_ops_group. + random_ops_group: dict[int, set[int]] = {} + random_group_head_idx = min(sac_stats.rand_ops, default=-1) + has_rand_ops = bool(sac_stats.rand_ops) + if has_rand_ops: + random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops) + + # 1. Random ops are stored if force_store_random is set + # 2. View-like ops are recomputed by default + # 3. For inplace_op_groups: + # a) If the head of this group is an inplace op, then we have to store the entire group. + # b) If any op in the group is random and force_store_random is set, then entire group will be stored. + # c) If none of ops in the group are random and the head of the group is not an in-place op, then + # this group can be considered for recomputation in its entirety + stored_ops: set[int] = set() + recomputed_ops: set[int] = set() + # Case 1: + if has_rand_ops and sac_stats.force_store_random: + stored_ops.add(random_group_head_idx) + # Case 2: + recomputed_ops.update(set(sac_stats.view_like_ops)) + + for group_head_idx, op_group in inplace_op_groups.items(): + # Case 3a: + if group_head_idx in inplace_op_to_group_head: + stored_ops.add(group_head_idx) + # Case 3b: + if ( + sac_stats.force_store_random & len(op_group & set(sac_stats.rand_ops)) + > 0 + ): + stored_ops.add(group_head_idx) + + # The potential recompute candidates are populated as: + recompute_candidates: set[int] = set() + # 1) The random group head if it is not stored + if has_rand_ops and random_group_head_idx not in stored_ops: + recompute_candidates.add(random_group_head_idx) + # 2) The in-place op group heads that are not stored + recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops) + # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default + recompute_candidates.update( + set(range(len(sac_stats.memory))) + - recomputed_ops + - stored_ops + - set(inplace_op_to_group_head.keys()) + - set(sac_stats.rand_ops) + ) + + # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second + msps_meta: list[MSPS] = [] + for cand_idx in recompute_candidates: + op_indices = {cand_idx} + if cand_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand_idx]) + if has_rand_ops and cand_idx == random_group_head_idx: + op_indices.update(sac_stats.rand_ops) + + mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices) + runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices) + func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} + msps = (mem / runtime) if runtime > 0 else sys.float_info.max + msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) + # We choose candidates to be recomputed based on increasing msps + msps_meta.sort(key=lambda x: x.msps, reverse=True) + return SACGreedyOrderMeta( + recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta + ) + + def _get_sac_tradeoff_pwlf_stats( + self, + sac_stats: SACStats, + greedy_order_meta: SACGreedyOrderMeta, + n_segments: int = 2, + save_tradeoff_graph: bool = False, + filename: str = "ac_tradeoff", + ) -> SACTradeOffStats: + try: + import numpy as np # type: ignore[import-not-found] + import pwlf # type: ignore[import-untyped, import-not-found] + except ImportError as err: + raise ImportError("Please install pwlf and numpy package.") from err + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + # 1. Initialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + recomp_indices: set[int] = set() + for r_idx in recomputed_ops: + recomp_indices.add(r_idx) + if r_idx in inplace_op_groups: + recomp_indices.update(inplace_op_groups[r_idx]) + if r_idx in random_ops_group: + recomp_indices.update(random_ops_group[r_idx]) + + discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices) + recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices) + # 2. Initialize the max recomputation time and total recomputation memory + sac_runtime = sum(sac_stats.runtimes) + sac_memory = sum(sac_stats.memory) + # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and, + # recomputation time to total runtime incurred. + delta = 1e-2 + tradeoff_curve = OrderedDict() + # 4. Initialize the trade-off curve with the stats of of already chosen recomputed_ops + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the + # greedy order of their ``MSPS``. + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 6. Finally, we add the memory and recomputation time of the always stored ops. + stored_indices: set[int] = set() + for s_idx in stored_ops: + stored_indices.add(s_idx) + if s_idx in inplace_op_groups: + stored_indices.update(inplace_op_groups[s_idx]) + if s_idx in random_ops_group: + stored_indices.update(random_ops_group[s_idx]) + discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices) + recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices) + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + x_ = list(tradeoff_curve.keys()) + y_ = list(tradeoff_curve.values()) + # 7. We shift the y values to left and x values to right to upperbound the trade-off function + # TODO: Write a better explanation why this needs to be done + x = x_[: len(x_) - 1] + y = y_[1:] + tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y) + # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve. + n_segments = max(min(len(x) - 2, n_segments), 1) + tradeoff_pwlf.fit(n_segments=n_segments) + + # save prediction graph + def save_prediction_graph( + pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str + ) -> None: + try: + import matplotlib.pyplot as plt # type: ignore[import-not-found] + import numpy as np # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "Install matplotlib and numpy using pip: pip install matplotlib numpy" + ) from err + # predict for the determined points + xHat = np.linspace(min(x), max(x), num=10000) + yHat = pwlf_.predict(xHat) + + # plot the results + plt.figure() + plt.plot(x, y, "o", label="Shifted") + plt.plot(xHat, yHat, "-", label="Predicted") + plt.plot(x_, y_, "x", label="Original") + plt.ylabel("Recomp time / Total recomp time") + plt.xlabel("Memory discarded / Total memory") + plt.legend() + plt.title(f"{filename}") + plt.suptitle( + f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms", + fontsize=10, + ) + folder_name = "tradeoff_graphs" + if not os.path.exists(folder_name): + os.makedirs(folder_name) + # Save the plots in the folder + plt.savefig(os.path.join(folder_name, f"{filename}.png")) + + if save_tradeoff_graph: + save_prediction_graph(tradeoff_pwlf, x, y, filename) + # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions + slopes = tradeoff_pwlf.calc_slopes().tolist() + assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance( + tradeoff_pwlf.fit_breaks, np.ndarray + ) + intercepts = tradeoff_pwlf.intercepts.tolist() + fit_breaks = tradeoff_pwlf.fit_breaks.tolist() + return SACTradeOffStats( + n_segments=n_segments, + slopes=slopes, + intercepts=intercepts, # type: ignore[arg-type] + fit_breaks=fit_breaks, # type: ignore[arg-type] + tradeoff_curve=tradeoff_curve, + sac_memory=sac_memory, + sac_runtime=sac_runtime, + ) + + def display_sac_stats( + self, sac_stats: SACStats, print_tabular: bool = False + ) -> None: + """ + Displays the SAC statistics. + + Args: + sac_stats (SACStats): The SAC statistics to display. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + 1. Total Memory: The total memory usage in bytes. + 2. Total Runtime: The total runtime in milliseconds. + 3. Store Random: A flag indicating whether to force store random operator results. + + Followed by a table with the following columns: + 1. Op Idx: The operator index. + 2. Op Name: The operator name. + 3. Runtimes (ms): The operator runtime in milliseconds. + 4. Memory (B): The operator memory usage in bytes. + 5. View-like: A flag indicating whether the operator is view-like. + 6. Random: A flag indicating whether the operator is random. + 7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine. + 8. In-place: The index of the operator's first parent, or None if not in-place. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + print( + f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" + f" Store Random: {sac_stats.force_store_random}" + ) + table_data = [] + op_parent = dict(sac_stats.inplace_ops) + for i, fn_name in enumerate(sac_stats.func_names): + row = [ + str(i), + fn_name, + f"{sac_stats.runtimes[i]:.4f}", + str(sac_stats.memory[i]), + str(i in sac_stats.view_like_ops), + str(i in sac_stats.rand_ops), + str(i in sac_stats.saved_autograd_ops), + str(op_parent.get(i)), + ] + table_data.append(row) + # Define headers + headers = [ + "Op Idx", + "Op Name", + "Runtimes(ms)", + "Memory (B)", + "View-like", + "Random", + "Saved Autograd", + "In-place", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def display_sac_tradeoff_stats( + self, + greedy_order_meta: SACGreedyOrderMeta, + sac_stats: SACStats, + print_tabular: bool = False, + ) -> None: + """ + Displays the SAC trade-off statistics. + + Args: + greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. + sac_stats (SACStats): The SAC statistics. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + A table with the following columns: + 1. Op Id(s): The operator index(es). + 2. Op Name(s): The operator name(s). + 3. Discarded Mem (%): The percentage of discarded memory. + 4. Discarded Mem (B): The discarded memory in bytes. + 5. Recomp time (%): The percentage of recomputed time. + 6. Recomp time (ms): The recomputed time in milliseconds. + 7. MSPS: The memory per second. + 8. Always Stored: A flag indicating whether the operator is always stored. + 9. Always Recomputed: A flag indicating whether the operator is always recomputed. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + table_data = [] + total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) + discarded_mem: int = 0 + recomp_runtime: float = 0.0 + + def append_row( + op_indices: set[int], + func_names: set[str], + msps: float | None = None, + stored: bool | None = False, + recomputed: bool | None = False, + ) -> None: + row = [ + str(op_indices), + str(func_names), + f"{discarded_mem / total_memory:.4f}", + str(discarded_mem), + f"{recomp_runtime / total_runtime:.4f}", + str(recomp_runtime), + f"{msps:.2e}" if msps is not None else str(nan), + str(stored), + str(recomputed), + ] + table_data.append(row) + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + + for op_idx in recomputed_ops: + op_indices: set[int] = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, recomputed=True) + + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + op_indices = {cand.op_idx} + if cand.op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand.op_idx]) + if cand.op_idx in random_ops_group: + op_indices.update(random_ops_group[cand.op_idx]) + append_row(op_indices, cand.func_names, msps=cand.msps) + + for op_idx in stored_ops: + op_indices = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, stored=True) + + headers = [ + "Op Id(s)", + "Op Name(s)", + "Discarded Mem (%)", + "Discarded Mem (B)", + "Recomp time (%)", + "Recomp time (ms)", + "MSPS", + "Always Stored", + "Always Recomputed", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def pwlf_sac_tradeoff_curve( + self, + n_segments: int = 2, + save_tradeoff_graphs: bool = False, + ) -> None: + """ + Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of + discarded memory vs recomputation time. + + Args: + n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to + the trade-off curve. Defaults to 2. + save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. + + If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( + sac_stats=sac_stats, + greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], + n_segments=n_segments, + save_tradeoff_graph=save_tradeoff_graphs, + filename=mod_fqn, + ) + + def display_modulewise_sac_stats( + self, depth: int = 2, print_tabular: bool = False + ) -> None: + """ + Displays the SAC and trade-off statistics for each module. + + Args: + depth (int, optional): The maximum depth of modules to display. Defaults to 2. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + For each module with depth less than or equal to the specified depth: + 1. The SAC statistics for the module (using display_sac_stats). + 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). + + If print_tabular is True, the statistics are printed in a tabular format. + Otherwise, the statistics are printed in a plain text format. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + self.display_sac_stats(sac_stats, print_tabular) + print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") + self.display_sac_tradeoff_stats( + self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular + ) + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + SACEstimator: The SAC estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate_runtime = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate_runtime = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + return self + + def __enter__(self) -> Self: # type: ignore[no-untyped-def] + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "SAC Estimator should be called in FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self._mod_tracker.register_user_hooks( + pre_fw_hook=self._pre_fw_hook, + post_fw_hook=self._post_fw_hook, + ) + self._mod_tracker.__enter__() + self._saved_tensor_hook_ctx.__enter__() + return super().__enter__() + + def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] + self._saved_tensor_hook_ctx.__exit__() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py new file mode 100644 index 0000000000000000000000000000000000000000..8799493f260a5967c8086aa3d24e64132cc4102d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_tools/sac_ilp.py @@ -0,0 +1,294 @@ +import logging +import math +from enum import IntEnum + +from torch.distributed._tools.ilp_utils import Graph, is_submodule +from torch.distributed._tools.sac_estimator import SACStats + + +try: + from pulp import ( # type: ignore[import-untyped,import-not-found] + lpDot, + LpInteger, + LpMaximize, + LpMinimize, + LpProblem, + LpStatus, + lpSum, + LpVariable, + PULP_CBC_CMD, + value, + ) +except ImportError as err: + raise ImportError( + "Please install pulp package. See: https://github.com/coin-or/pulp." + ) from err + +# Create a logger object +logger = logging.getLogger(__name__) + +# Set the logging level to INFO +logger.setLevel(logging.INFO) + + +def sac_milp( + graph: Graph, + memory_budget: float, + world_size: int = 1, + ac_units: list[str] | None = None, + fsdp_units: list[str] | None = None, +) -> tuple[dict[str, float], float, int]: + """ + MILP to decide which modules to AC and how much memory to discard. + The objective is to minimize recomputation time. + The constraint is to ensure peak memory is under budget. + + Args: + graph: graph representation of the model as a module submodule tree + where each node is a submodule with memory & runtime stats + memory_budget: memory budget in GiB + world_size: number of GPUs. In the case of FSDP, world_size will be + used to compute the amount of parameter and gradient memory on each rank + ac_units: a list of user-specified AC units. + fsdp_units: a list of FSDP units. AC units cannot be supermodules of FSDP units. + + Returns: + Dict[str, float]: the optimal SAC solution, mapping from module fqn to + the percentage of activation memory to **discard** + float: the recomputation time of the optimal SAC solution + int: upper bound on the peak memory of the optimal SAC solution. + note that value of -1 means that the ILP solver failed to find a solution. + + """ + num_nodes = len(graph.nodes) + M = 10**2 # note: numerical issue may occur if M is too big + MEM_MULTIPLIER = 2**30 + + # Create a MILP problem + prob = LpProblem("SAC", LpMinimize) + + # Create decision variables + # y_i: indicator for if module i is AC'ed + y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger) + # r_i: percentage of discarded activation memory + r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1) + # d_i: discarded activation memory for module i + d = LpVariable.matrix("d", list(range(num_nodes)), 0) + # a_i: total activation memory at module i + a = LpVariable.matrix("a", list(range(num_nodes)), 0) + # m_i: memory at module i, combining parameters, gradients, and activations + m = LpVariable.matrix("m", list(range(num_nodes)), 0) + # rcp_i: percentage of recomputation time + rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0) + # rct_i: recomputation time for module i (in ms) + rct = LpVariable.matrix("rct", list(range(num_nodes)), 0) + # max_m: peak memory + max_m = LpVariable("max_m", 0) + + # Add constraints + # [Constraint] User specified AC units + if ac_units: + ac_units_set = set(ac_units) + for i in range(num_nodes): + if graph.nodes[i]["fqn"] not in ac_units_set: + prob += y[i] == 0 + + # [Constraint] AC units cannot be supmodules of user specified FSDP units + if fsdp_units: + for i in range(num_nodes): + if any( + is_submodule(fsdp_unit, graph.nodes[i]["fqn"]) + for fsdp_unit in fsdp_units + ): + prob += y[i] == 0 + + # [Constraint] No nested AC units + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if graph.ad_matrix[i][j] == 1: + prob += y[i] + y[j] <= 1 + + # [Constraint] Do not AC leaf modules + for i in range(num_nodes): + if graph.nodes[i]["is_leaf"]: + prob += y[i] == 0 + + # [Constraint] Express amount of discarded activation memory + for i in range(num_nodes): + # There are two measures for activation memory: ACM and IA + # 1. IA is the activation memory saved when not using AC + # 2. ACM is the total activation memory, including those + # that are not typically saved when not using AC + # Note: ACM >= IA + if (not graph.nodes[i]["is_leaf"]) and graph.nodes[i][ + "sac_memory" + ] < graph.nodes[i]["act_fw_per_module"]: + logger.warning("For module {%s}: ", graph.nodes[i]["fqn"]) + logger.warning( + "activation memory from memory tracker is {%d},", + graph.nodes[i]["act_fw_per_module"], + ) + logger.warning( + "activation memory from SAC estimator is {%d}.", + graph.nodes[i]["sac_memory"], + ) + logger.warning("Something is wrong. Please check!") + logger.warning("Overriding the latter with the former.") + graph.nodes[i]["sac_memory"] = graph.nodes[i]["act_fw_per_module"] + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i] + + # [Constraint] Ensure correctness of r_i + # There are two parts to its correctness + # 1. r_i > 0 only if y_i == 1 (discard only if it is an AC unit) + # 2. r_i needs to be large enough to cover the difference between + # ACM and IA. Otherwise, we are not saving any memory + for i in range(num_nodes): + prob += y[i] >= r[i] + if graph.nodes[i]["is_leaf"]: + continue + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i] + + # [Constraint] Express total activation memory in the backward pass + for i in range(num_nodes): + AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER + TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER + # related to discarded amount of memory + pos = graph.nodes[i]["pos_fw_post_order"] + coeff = [0] * num_nodes + for p in range(pos): + j = graph.name2node[graph.fw_post_order[p]]["index"] + coeff[j] = 1 + prob += a[i] == TA_i + AG_i - lpDot(coeff, d) + + # [Constraint] Express the total amount of memory at each module + # Note that unsharded parameters and gradients are not included here + P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] / MEM_MULTIPLIER + prob += m[i] == a[i] + (P_1 + TG_i) / world_size + + # [Constraint] Express peak memory + for i in range(num_nodes): + prob += max_m >= m[i] + + # [Constraint] Express percentage of recomputation time + for i in range(num_nodes): + for s in range(graph.nodes[i]["n_segments"]): + slope = graph.nodes[i]["slopes"][s] + intercept = graph.nodes[i]["intercepts"][s] + prob += rcp[i] >= slope * r[i] + intercept + + # [Constraint] Express recomputation time + # rct_i = (rcp_i * ACT_i) if y_i == 1 else 0 + for i in range(num_nodes): + ACT_i = graph.nodes[i]["sac_runtime"] + prob += rct[i] <= M * y[i] + prob += rct[i] <= ACT_i * rcp[i] + prob += rct[i] >= ACT_i * rcp[i] - M * (1 - y[i]) + + # [Constraint] Peak memory should be below budget + prob += max_m <= memory_budget + + # Set Objeictive + prob += lpSum(rct) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=180, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return {}, 0, -1 + + # Gather and return solution if optimal solution is found + ac_decisions = {} + for i in range(num_nodes): + if round(y[i].varValue) == 1: + ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4) + recomputation_time = round(value(prob.objective), 2) + peak_mem = round(max_m.varValue * MEM_MULTIPLIER) + + return ac_decisions, recomputation_time, peak_mem + + +class SACDecision(IntEnum): + RECOMPUTE = 0 + SAVE = 1 + + +def get_optimal_checkpointing_policy_per_module( + sac_stats: SACStats, memory_budget: float +) -> list[int]: + """ + This is adapted from -- + https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 + + Given the SACStats of a module, including list of operators, their memory, runtimes, and metadata, + decide via MILP an optimal set of operators to checkpoint under a given ``memory_budget``. + + Args: + sac_stats: the SACStats object of the module + memory_budget: a float between zero and one + + Returns: + List[int]: the decision whether each operator should be saved (1) or recomptued (0). + """ + if not (0 <= memory_budget <= 1): + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + num_ops = len(sac_stats.func_names) + + # Create a MILP problem + prob = LpProblem("SAC-per-module", LpMaximize) + + # Create decision variables + # x[i] = 1 means the i-th operator should be saved, otherwise it should be recomputed + x = LpVariable.matrix("x", list(range(num_ops)), 0, 1, LpInteger) + + # Add constraints + # [Constraint] random ops should be saved if ``force_store_random`` is True + # otherwise, random ops should either be all recomputed or all saved + if sac_stats.force_store_random: + for i in sac_stats.rand_ops: + prob += x[i] == SACDecision.SAVE.value + else: + for i1, i2 in zip(sac_stats.rand_ops[:-1], sac_stats.rand_ops[1:]): + prob += x[i1] == x[i2] + + # [Constraint] view-like ops should always be recomputed + for i in sac_stats.view_like_ops: + prob += x[i] == SACDecision.RECOMPUTE.value + + # [Constraint] inplace ops should always be done in conjunction with its parent op + for op, op_parent in sac_stats.inplace_ops: + if op != op_parent: + prob += x[op] == x[op_parent] + else: + prob += x[op] == SACDecision.SAVE.value + + # [Constraint] saved memory should be under the ``memory_budget`` + max_memory = math.ceil(memory_budget * sum(sac_stats.memory)) + prob += lpDot(x, sac_stats.memory) <= max_memory + + # [Objective] minimize recomputation time, note the ILP is a maximization problem + # because x[i] == 1 means the op is saved (not recomputed), and thus recomputation + # time is sum(sac_stats.runtimes) - lpDot(x, sac_stats.runtimes) + prob += lpDot(x, sac_stats.runtimes) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=10, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return [] + + # Gather and return solution if optimal solution is found + return [round(x[i].varValue) for i in range(num_ops)] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06c814295699405de9a8f8cf7f6a861b07b63a05 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/__init__.py @@ -0,0 +1 @@ +from .join import Join, Joinable, JoinHook diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/join.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/join.py new file mode 100644 index 0000000000000000000000000000000000000000..52d0c52fbfb59d3c906bd282db51a76886206c96 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/algorithms/join.py @@ -0,0 +1,350 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, NamedTuple + +import torch +import torch.distributed as dist + + +__all__ = ["JoinHook", "Joinable", "Join"] + + +class JoinHook: + r""" + This defines a join hook, which provides two entry points in the join context manager. + + Entry points : a main hook, which is called repeatedly while there exists a non-joined + process, and a post-hook, which is called once all processes have joined. + + To implement a join hook for the generic join context manager, define a + class that inherits from :class:`JoinHook` and override ``main_hook()`` and + ``post_hook()`` as appropriate. + """ + + def main_hook(self) -> None: + r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration. + + Training iteration i.e., in one forward pass, backward pass, and optimizer step. + """ + + def post_hook(self, is_last_joiner: bool) -> None: + r""" + Call hook after all processes have joined. + + It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join. + + Arguments: + is_last_joiner (bool): ``True`` if the rank is one of the last to + join; ``False`` otherwise. + """ + + +class Joinable(ABC): + r""" + This defines an abstract base class for joinable classes. + + A joinable class + (inheriting from :class:`Joinable`) should implement :meth:`join_hook`, + which returns a :class:`JoinHook` instance, in addition to + :meth:`join_device` and :meth:`join_process_group` that return device and + process group information, respectively. + """ + + @abstractmethod + def __init__(self) -> None: + super().__init__() + self._join_config = _JoinConfig.construct_disabled_join_config() + + @abstractmethod + def join_hook(self, **kwargs) -> JoinHook: + r""" + Return a :class:`JoinHook` instance for the given :class:`Joinable`. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + """ + ... + + @property + @abstractmethod + def join_device(self) -> torch.device: + r"""Return the device from which to perform collective communications needed by the join context manager.""" + ... + + @property + @abstractmethod + def join_process_group(self) -> Any: + r"""Returns the process group for the collective communications needed by the join context manager itself.""" + ... + + +class _JoinConfig(NamedTuple): + r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side.""" + + enable: bool + throw_on_early_termination: bool + is_first_joinable: bool + + @staticmethod + def construct_disabled_join_config(): + r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled. + + e.g. if the caller is not in a join context manager. + """ + return _JoinConfig( + enable=False, throw_on_early_termination=False, is_first_joinable=False + ) + + +class Join: + r""" + This class defines the generic join context manager, which allows custom hooks to be called after a process joins. + + These hooks should shadow the + collective communications of non-joined processes to prevent hanging and + erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook` + for details about the hook definition. + + .. warning:: + The context manager requires each participating :class:`Joinable` to + call the method :meth:`notify_join_context()` before its own per- + iteration collective communications to ensure correctness. + + .. warning:: + The context manager requires that all ``process_group`` attributes in + the :class:`JoinHook` objects are the same. If there are multiple + :class:`JoinHook` objects, then the ``device`` of the first is used. + The process group and device information is used for checking for non- + joined processes and for notifying processes to throw an exception if + ``throw_on_early_termination`` is enabled, both of which using an all- + reduce. + + Arguments: + joinables (List[Joinable]): a list of the participating + :class:`Joinable` s; their hooks are iterated over in the given + order. + + enable (bool): a flag enabling uneven input detection; setting to + ``False`` disables the context manager's functionality and should + only be set when the user knows the inputs will not be uneven + (default: ``True``). + + throw_on_early_termination (bool): a flag controlling whether to throw an + exception upon detecting uneven inputs (default: ``False``). + + Example:: + + >>> import os + >>> import torch + >>> import torch.distributed as dist + >>> import torch.multiprocessing as mp + >>> # xdoctest: +SKIP + >>> import torch.nn.parallel.DistributedDataParallel as DDP + >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO + >>> from torch.distributed.algorithms.join import Join + >>> + >>> # On each spawned worker + >>> def worker(rank): + >>> dist.init_process_group("nccl", rank=rank, world_size=2) + >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) + >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) + >>> # Rank 1 gets one more input than rank 0 + >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] + >>> with Join([model, optim]): + >>> for input in inputs: + >>> loss = model(input).sum() + >>> loss.backward() + >>> optim.step() + >>> # All ranks reach here without hanging/erroring + """ + + def __init__( + self, + joinables: list[Joinable], + enable: bool = True, + throw_on_early_termination: bool = False, + **kwargs, + ): + if len(joinables) == 0: + raise ValueError("The join context manager requires at least one joinable") + self._joinables = joinables + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] + self._enable = enable + self._throw_on_early_termination = throw_on_early_termination + self._set_joinable_configs() + self._extract_dist_info() + + def _set_joinable_configs(self) -> None: + r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`.""" + assert len(self._joinables) > 0 + is_first_joinable = True + for joinable in self._joinables: + joinable._join_config = _JoinConfig( + enable=self._enable, + throw_on_early_termination=self._throw_on_early_termination, + is_first_joinable=is_first_joinable, + ) + is_first_joinable = False + + def _extract_dist_info(self) -> None: + r""" + Extract the process group and device information from the joinables. + + If there are multiple joinables, then the context manager uses the + first specified device. + + Preconditions: + ``self._joinables`` is not ``None`` and is non-empty. + + Raises: + ValueError + If there are multiple conflicting ``process_group`` attributes + among the ``Joinable`` objects. + """ + process_group = None + device = None + # pyrefly: ignore [bad-assignment] + for joinable in self._joinables: + if process_group is None: + process_group = joinable.join_process_group + elif process_group != joinable.join_process_group: + raise ValueError( + "Using join context manager with multiple process groups" + ) + if device is None: + device = joinable.join_device + self._process_group = process_group + self._rank = dist.get_rank(self._process_group) + self._device = device + + def __enter__(self): ... + + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ): + r""" + Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. + + Raises: + RuntimeError + If ``throw_on_early_termination=True``. + """ + if not self._enable or type: + return # propagate the exception directly if one was raised + + all_procs_joined = False + is_last_joiner = True + + i = 0 + WARN_THRESHOLD = 1000 + warnings.simplefilter("once") + + while not all_procs_joined: + if i > WARN_THRESHOLD: + warnings.warn( + "Detected uneven input skew of greater than " + f"{WARN_THRESHOLD}. This means that rank " + f"{self._rank} has at least {WARN_THRESHOLD} " + f"fewer inputs than other currently-active ranks. " + "This level of skew could lead to performance " + "degradation during training.", + stacklevel=2, + ) + # Shadow the all-reduce in non-joined processes + num_nonjoined_procs = self._get_num_nonjoined_procs() + if num_nonjoined_procs == 0: + all_procs_joined = True + else: + if self._throw_on_early_termination: + self._notify_procs_to_terminate() + + # Run main hooks + for join_hook in self._join_hooks: + join_hook.main_hook() + + is_last_joiner = False + i += 1 + + # Run post-hooks + for join_hook in self._join_hooks: + join_hook.post_hook(is_last_joiner) + + def _get_num_nonjoined_procs(self): + r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes.""" + num_nonjoined_procs = torch.zeros(1, device=self._device) + dist.all_reduce(num_nonjoined_procs, group=self._process_group) + return num_nonjoined_procs.item() + + def _notify_procs_to_terminate(self): + r"""Schedule an all-reduce to notify non-joined processes to terminate. + + Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs. + """ + ones = torch.ones(1, device=self._device) + dist.all_reduce(ones, group=self._process_group) + raise RuntimeError(f"Rank {self._rank} exhausted all inputs.") + + @staticmethod + def notify_join_context(joinable: Joinable): + r""" + Notifies the join context manager that the calling process has not yet joined. + + Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected + (i.e. if one process has already joined) and throws an exception if so. + + This method should be called from a :class:`Joinable` object before + its per-iteration collective communications. For example, this should + be called at the beginning of the forward pass in + :class:`DistributedDataParallel`. + + Only the first :class:`Joinable` object passed into the context + manager performs the collective communications in this method, and + for the others, this method is vacuous. + + Arguments: + joinable (Joinable): the :class:`Joinable` object calling this + method. + + Returns: + An async work handle for the all-reduce meant to notify the context + manager that the process has not yet joined if ``joinable`` is the + first one passed into the context manager; ``None`` otherwise. + """ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " + "``Joinable`` constructor" + ) + + join_config = joinable._join_config + # First joinable is responsible for the collective communications + if not join_config.is_first_joinable or not join_config.enable: + return None + + device = joinable.join_device + process_group = joinable.join_process_group + + # Schedule an all-reduce to indicate that the caller has not yet joined + ones = torch.ones(1, device=device) + work = dist.all_reduce(ones, group=process_group, async_op=True) + + if join_config.throw_on_early_termination: + # Check if uneven inputs have been detected + zeros = torch.zeros(1, device=device) + dist.all_reduce(zeros, group=process_group) + should_throw = zeros.item() + if should_throw: + raise RuntimeError( + "Detected at least one rank that exhausted inputs. " + "Throwing across all ranks." + ) + return work diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a52c36942e48e389a7e344abeb929febdb62c6c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/autograd/__init__.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from types import TracebackType + + +def is_available() -> bool: + return hasattr(torch._C, "_dist_autograd_init") + + +if is_available() and not torch._C._dist_autograd_init(): + raise RuntimeError("Failed to initialize torch.distributed.autograd") + +if is_available(): + from torch._C._distributed_autograd import ( + _current_context, + _get_debug_info, + _get_max_id, + _init, + _is_valid_context, + _new_context, + _release_context, + _retrieve_context, + backward, + DistAutogradContext, + get_gradients, + ) + +__all__ = ["context", "is_available"] + + +class context: + """ + Context object to wrap forward and backward passes when using + distributed autograd. The ``context_id`` generated in the ``with`` + statement is required to uniquely identify a distributed backward pass + on all workers. Each worker stores metadata associated with this + ``context_id``, which is required to correctly execute a distributed + autograd pass. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> t1 = torch.rand((3, 3), requires_grad=True) + >>> t2 = torch.rand((3, 3), requires_grad=True) + >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() + >>> dist_autograd.backward(context_id, [loss]) + """ + + def __enter__(self) -> int: + self.autograd_context = _new_context() + return self.autograd_context._context_id() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + _release_context(self.autograd_context._context_id()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8104a8df99f0b5c4a4f1db57ac98602a61666626 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,21 @@ +from . import _extension +from .api import CheckpointException +from .default_planner import DefaultLoadPlanner, DefaultSavePlanner +from .filesystem import FileSystemReader, FileSystemWriter +from .hf_storage import HuggingFaceStorageReader, HuggingFaceStorageWriter +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + TensorStorageMetadata, +) +from .optimizer import load_sharded_optimizer_state_dict +from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem +from .quantized_hf_storage import QuantizedHuggingFaceStorageReader + +# pyrefly: ignore [deprecated] +from .state_dict_loader import load, load_state_dict + +# pyrefly: ignore [deprecated] +from .state_dict_saver import async_save, save, save_state_dict +from .storage import StorageReader, StorageWriter diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a358e995fa408265af56b71a218478a6114a5d7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaec9cf9c4bf36522eec1ff77aa5d18e121c33c3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_async_executor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0593ffffe1d74f6073ab89f9593b47a3f84f97b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aec7b8c1a2241e0a52df543c50eb020798d934e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_consolidate_hf_safetensors.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa416a594a18bebb233591115092d857fba8580c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_extension.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25f788ce9825a1eb6f00ba59b975d6f36b3d707 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41aebc35cf725c67f4290946aa750d19414ce184 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_pg_transport.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81b746ad0d4cd311d5b7ca921b31f2050b48f7b3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c76d14bf26edc14e584cab9b2f1abe4a5fecf623 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8610602531494d7a2d6502292d109388deba65b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69d5cb2121fe3d37f580f2c38889368bc9f0d048 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c50fed3463480e5a8f7ec979dbf6a7c598f64d4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee3f3c41f4b19578dcbeaba2ca5d900cc5638445 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/hf_storage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69fb8d3e36e8dc19c227e51848f9b6ca184ad47b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logger.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d1953ffef7a447a62f92a0fb2c600b2e09e20b8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..197f269eb9974e01271e94fe2c9844a5464e3fde Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5d2e5219f53722a8cf35b298409bab16a15e469 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f5332eb9d915b9e88eeb31c91a94b69ba23e06c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05243e993e60aee60cf34e9199d8baebb822ae39 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e9e6adcd310259f9949167484421d4d317af3a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/staging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5703c3a07ed8d4014abb78a46d72778263e9bce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74bb9017f57715974e9842581b0e19d9cab73e15 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f544bbd123e681df48b3d053af60b130125f1a07 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e486f40d796f6bf2f8f3bc4d6e6f65d99fc10eb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..428c697b91e9b567e99d52714a8248d322798073 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_executor.py @@ -0,0 +1,34 @@ +# pyre-strict +# mypy: allow-untyped-defs +import abc +import os +from concurrent.futures import Future +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +class _AsyncCheckpointExecutor(abc.ABC): + @abc.abstractmethod + def execute_save( + self, + staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + """ + Execute the checkpoint save request asynchronously. + + This method is intended to be used as an abstraction for + implementing async checkpointing. The actual checkpoint save + operation is executed in a separate thread or process depending + on the implementation of this interface. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..48390253c302a5acc9806ecac587a24022262565 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_process_executor.py @@ -0,0 +1,455 @@ +# pyre-strict +# mypy: allow-untyped-defs +import gc +import logging +import os +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Union +from uuid import uuid4 + +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed import PrefixStore, TCPStore +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.elastic.agent.server.api import _get_fq_hostname +from torch.distributed.elastic.utils.distributed import get_free_port + + +logger = logging.getLogger() + + +class _CheckpointSaveProcessControlOpts(Enum): + INIT_COMPLETE = "init_complete" + TERMINATE = "terminate" + + +@dataclass(init=False, unsafe_hash=True) +class _CheckpointRequestIdentifier: + checkpoint_id: Union[str, os.PathLike, None] + uuid: str + + def __init__(self, checkpoint_id: Union[str, os.PathLike, None]): + self.checkpoint_id = checkpoint_id + self.uuid = str(uuid4()) + + +@dataclass +class _AsyncCheckpointRequest: + staged_state_dict: STATE_DICT_TYPE + checkpoint_request_id: _CheckpointRequestIdentifier + storage_writer: Optional[StorageWriter] = None + planner: Optional[SavePlanner] = None + no_dist: bool = False + use_collectives: bool = True + + +@dataclass(init=False) +class _ProcessGroupInitInfo: + local_rank: int + global_rank: int + world_size: int + tcp_store_master_addr: str + tcp_store_master_port: int + use_prefix_store: bool + disable_automatic_gc: bool + disable_manual_gc: bool + + def __init__(self, process_group: Optional[dist.ProcessGroup] = None): + self.local_rank = dist.get_node_local_rank(fallback_rank=0) + self.global_rank = dist.get_rank(process_group) + self.world_size = dist.get_world_size(process_group) + self.use_prefix_store = os.environ.get("DCP_USE_PREFIX_STORE", "0") == "1" + self.disable_automatic_gc = ( + os.environ.get("DCP_DISABLE_AUTOMATIC_GC", "0") == "1" + ) + self.disable_manual_gc = os.environ.get("DCP_DISABLE_MANUAL_GC", "0") == "1" + + # Let coordinator rank find a port on the localhost. + # Broadcast the (master_addr, port) to all ranks; each rank in the + # checkpoint daemon process will use TCPStore (master_addr, port) + # for collective communication. + dist_wrapper: _DistWrapper = _DistWrapper( + group=process_group, + use_dist=True, + coordinator_rank=0, + ) + + def get_master_addr_and_port() -> tuple[str, int]: + if self.use_prefix_store: + master_addr = os.environ.get("MASTER_ADDR") + master_port = os.environ.get("MASTER_PORT") + assert master_addr is not None, ( + "DCP needs MASTER_ADDR to use prefix store" + ) + assert master_port is not None, ( + "DCP needs MASTER_PORT to use prefix store" + ) + master_port = int(master_port) + else: + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + master_addr = _get_fq_hostname() + master_port = get_free_port() + + return master_addr, master_port + + self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast( + step="get_master_addr_and_port", + map_fun=get_master_addr_and_port, + ) + + +class _AsyncCheckpointProcess: + def __init__( + self, + pg_init_info: _ProcessGroupInitInfo, + ): + self.ctx = mp.get_context("spawn") + self._process_pipe, child_end = self.ctx.Pipe() + + self._save_process = self.ctx.Process( + target=self._checkpointing_subprocess, + args=( + pg_init_info, + child_end, + ), + daemon=True, + ) + + self._save_process.start() + + # Close the parent's copy of child end after we pass it into the child, + # so the recv()s on it will fail-fast if the child process dies. + child_end.close() + + # Wait for the checkpoint background process to initialize. + # Using default GLOO init timeout. + response = self._wait_for_response(timeout=1800) + if not response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE: + raise AssertionError(f"Expected INIT_COMPLETE response, got {response}") + + def __del__(self) -> None: + if self._save_process.is_alive(): + try: + logger.info("Terminating the checkpoint background process.") + self._send(_CheckpointSaveProcessControlOpts.TERMINATE) + self._save_process.join(timeout=5) + finally: + if self._save_process.is_alive(): + logger.warning( + "Checkpoint background process is still alive after termination request. Sending SIGTERM." + ) + self._save_process.terminate() + + def _send(self, data: Any) -> None: + self._process_pipe.send(data) + + def _wait_for_response(self, timeout: Optional[float] = None) -> Any: + if not self._save_process.is_alive(): + logger.info("Checkpoint background process is dead calling join()...") + self._save_process.join() + raise RuntimeError( + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if timeout is not None and not self._process_pipe.poll(timeout=timeout): + raise RuntimeError( + f"Timed out after {timeout}s while waiting for response from checkpointer process pid: {self._save_process.pid}" + ) + + try: + response = self._process_pipe.recv() + except EOFError: + raise RuntimeError( # noqa: B904 + f"Checkpoint background process is dead. Exit code: {self._save_process.exitcode}" + ) + + if isinstance(response, BaseException): + raise response + + return response + + def save( + self, + staged_state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + # Create a unique identifier to locate requests/responses + # from the checkpoint daemon process. + checkpoint_request_id = _CheckpointRequestIdentifier(checkpoint_id) + async_cp_request = _AsyncCheckpointRequest( + staged_state_dict=staged_state_dict, + checkpoint_request_id=checkpoint_request_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + self._send(async_cp_request) + result = self._wait_for_response() + if not isinstance(result, Metadata): + raise AssertionError(f"Expected Metadata response, got {type(result)}") + return result + + @staticmethod + def _execute_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_request_id: _CheckpointRequestIdentifier, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + from torch.distributed.checkpoint.state_dict_saver import save + + metadata = save( + state_dict, + checkpoint_id=checkpoint_request_id.checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + return metadata + + @staticmethod + def _checkpointing_subprocess( + pg_init_info: _ProcessGroupInitInfo, + parent_conn, + ) -> None: + # Phase 1: Process Group Initialization + # Only needs to execute once during the lifetime of the checkpoint background process. + try: + _init_logger(pg_init_info.global_rank) + + # Setup environment variables for process group initialization. + os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" + os.environ["MASTER_ADDR"] = pg_init_info.tcp_store_master_addr + os.environ["MASTER_PORT"] = str(pg_init_info.tcp_store_master_port) + os.environ["LOCAL_RANK"] = str(pg_init_info.local_rank) + os.environ["RANK"] = str(pg_init_info.global_rank) + os.environ["WORLD_SIZE"] = str(pg_init_info.world_size) + + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process on port %s", + pg_init_info.tcp_store_master_port, + ) + # NOTE: GLOO backend is enforced here. + if pg_init_info.use_prefix_store: + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process with prefix store" + ) + store = PrefixStore( + "AsyncCheckpointProcess/", + TCPStore( + pg_init_info.tcp_store_master_addr, + pg_init_info.tcp_store_master_port, + ), + ) + dist.init_process_group( + backend=dist.Backend.GLOO, + store=store, + world_size=pg_init_info.world_size, + rank=pg_init_info.global_rank, + ) + else: + dist.init_process_group(backend=dist.Backend.GLOO) + dist.barrier() + + logger.info("Checkpoint background process is running...") + parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE) + + if pg_init_info.disable_automatic_gc: + # Disable automatic garbage collection + # GC can optionally be called manually after each checkpoint + gc.disable() + logger.info("Disabled automatic garbage collection") + except BaseException as e: # noqa: B036 + logger.error( + f"Checkpoint background process failed during initialization: {e}" # noqa: G004 + ) + parent_conn.send(e) + return + + # Phase 2: Serving Loop + try: + first_request = True + while True: + logger.info("Waiting for checkpoint save request...") + obj = parent_conn.recv() + if ( + isinstance(obj, _CheckpointSaveProcessControlOpts) + and obj == _CheckpointSaveProcessControlOpts.TERMINATE + ): + logger.info("Terminating the checkpoint background process.") + return + if not isinstance(obj, _AsyncCheckpointRequest): + raise AssertionError( + f"Expected _AsyncCheckpointRequest, got {type(obj)}" + ) + logger.info( + f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004 + ) + + try: + response = _AsyncCheckpointProcess._execute_save( + obj.staged_state_dict, + checkpoint_request_id=obj.checkpoint_request_id, + storage_writer=obj.storage_writer, + planner=obj.planner, + no_dist=obj.no_dist, + use_collectives=obj.use_collectives, + ) + parent_conn.send(response) + logger.info( + f"Completed checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004 + ) + + # in theory this manual gc should not be needed as we shouldn't be leaking anything from checkpointing process + if ( + pg_init_info.disable_automatic_gc + and not pg_init_info.disable_manual_gc + ): + del obj + + collected_objects = gc.collect() + + logger.info( + f"Manual garbage collection completed - collected {collected_objects} objects." # noqa: G004 + ) + if first_request: + # Freeze GC to not check GC for large checkpoint save plans + # After freezing, subsequent gc.collect() calls will only scan + # NEW objects created after this point, not the frozen save plan + logger.info( + "First checkpoint request completed - freezing gc" + ) + gc.freeze() + first_request = False + except BaseException as e: # noqa: B036 + logger.error( + f"Checkpoint save failed for checkpoint_id={obj.checkpoint_request_id.checkpoint_id}: {e}" # noqa: G004 + ) + parent_conn.send(e) + # Continue serving loop - don't exit process + finally: + logger.info("Checkpoint background process is shutting down...") + dist.destroy_process_group() + parent_conn.close() + + +_CHECKPOINT_PROCESS: Optional[_AsyncCheckpointProcess] = None + + +class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + @staticmethod + def _execute_save_impl( + *, + pg_init_info: Optional[_ProcessGroupInitInfo], + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Metadata: + global _CHECKPOINT_PROCESS + if _CHECKPOINT_PROCESS is None: + if pg_init_info is None: + raise AssertionError( + "pg_init_info must not be None when _CHECKPOINT_PROCESS is None" + ) + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = process_group + + @_dcp_method_logger(**ckpt_kwargs) + def create_checkpoint_daemon_process() -> None: + global _CHECKPOINT_PROCESS + # pyrefly: ignore [bad-argument-type] + _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info) + + create_checkpoint_daemon_process() + + if _CHECKPOINT_PROCESS is None: + raise AssertionError( + "_CHECKPOINT_PROCESS must not be None after initialization" + ) + staged_state_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) + return _CHECKPOINT_PROCESS.save( + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + def execute_save( + self, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + """ + NOTE: + + - Checkpoint process is implemented as a daemon process. + The AsyncCheckpointProcess' lifetime is tied to the lifetime of the + main process (e.g. trainer process). + + - The first call to execute_save_in_process() will initialize the checkpoint + daemon process. Subsequent async checkpoint requests will not need process + initialization. Therefore, the first async checkpoint request will take longer to complete. + + - Process initialization can have significant overhead, dominated by latency for all ranks to spawn + a background process + process group initialization in the background process. + """ + + global _CHECKPOINT_PROCESS + pg_init_info: Optional[_ProcessGroupInitInfo] = None + if _CHECKPOINT_PROCESS is None: + # Find a port on coordinator rank and broadcast + # to all ranks. + pg_init_info = _ProcessGroupInitInfo(process_group) + + f: Future = self._executor.submit( + self._execute_save_impl, + pg_init_info=pg_init_info, + staging_future_or_state_dict=staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + no_dist=no_dist, + use_collectives=use_collectives, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..8dfe63413d433c75a012916f65628f2bd4e57f20 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_async_thread_executor.py @@ -0,0 +1,71 @@ +# pyre-strict +# mypy: allow-untyped-defs +import os +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Optional, Union + +import torch.distributed as dist +from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import StorageWriter + + +def save_wrapper( + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + staged_dict = ( + staging_future_or_state_dict.result() + if isinstance(staging_future_or_state_dict, Future) + else staging_future_or_state_dict + ) + return save( + staged_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + +class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="AsyncCheckpointExecutor" + ) + + def execute_save( + self, + staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, + ) -> Future: + f: Future = self._executor.submit( + save_wrapper, + staging_future_or_state_dict=staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) + + return f diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..13b0d627a36cc0fedc75695932260ecec718bcde --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py @@ -0,0 +1,103 @@ +from concurrent.futures import Future +from typing import Any, Optional + +import torch.distributed as dist +import torch.distributed.checkpoint.state_dict_loader as loader +import torch.distributed.checkpoint.state_dict_saver as saver +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.storage import ( + LoadPlanner, + SavePlanner, + StorageReader, + StorageWriter, +) + + +__all__: list[str] = [] + + +class _Checkpointer: + """This base class specifies a high level API for saving and loading + distributed `state_dict` 's. It provides an abstraction over the low-level APIs + provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling + :py:meth: `torch.distributed.state_dict_saver.save` and + :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage + readers and writers. + + .. warning:: + This feature is experimental and subject to removal/change. + + """ + + def __init__( + self, + storage_writer: StorageWriter, + storage_reader: StorageReader, + *, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + load_planner: Optional[LoadPlanner] = None, + save_planner: Optional[SavePlanner] = None, + ): + """Initializes the Checkpointer instance. + + Args: + storage_writer: Instance of StorageWrite use to perform writes. + storage_reader: StorageReader used to load data from. + process_group: ProcessGroup to be used for cross-rank synchronization. + coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. + no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) + loader_planner: Instance of LoadPlanner to use when loading. + save_planner: Instance of SavePlanner to use when saving. + """ + self.storage_writer = storage_writer + self.storage_reader = storage_reader + self.process_group = process_group + self.coordinator_rank = coordinator_rank + self.no_dist = no_dist + self.load_planner = load_planner + self.save_planner = save_planner + + def save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Metadata: + """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" + return saver.save( + state_dict, + self.storage_writer, + process_group=self.process_group, + coordinator_rank=self.coordinator_rank, + no_dist=self.no_dist, + planner=self.save_planner, + ) + + def async_save( + self, + state_dict: STATE_DICT_TYPE, + ) -> Future: + """ + Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + """ + response = saver.async_save( + state_dict, + storage_writer=self.storage_writer, + process_group=self.process_group, + planner=self.save_planner, + ) + if not isinstance(response, Future): + raise AssertionError("response should be a Future instance") + return response + + def load(self, state_dict: dict[str, Any]) -> None: + """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" + loader.load( + state_dict, + storage_reader=self.storage_reader, + process_group=self.process_group, + planner=self.load_planner, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py new file mode 100644 index 0000000000000000000000000000000000000000..32d81fb1ea7213e7672a9e7fe23b030962a354f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -0,0 +1,716 @@ +# pyre-strict + +import concurrent.futures +import glob +import json +import logging +import math +import mmap +import os +import struct +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from torch import distributed as dist +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dcp_custom_metadata, + _get_safetensors_file_metadata, + _metadata_fn, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, + SUFFIX, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class _FqnData: + """ + Dataclass to store information about a tensor (identified by its fully qualified name). + + Attributes: + offset_in_file: Byte offset where this tensor's data begins in the output file + shape_in_file: Shape of the tensor in the output file + dtype_size: Size of the tensor's data type in bytes + dtype_str: String representation of the tensor's data type + """ + + offset_in_file: int = 0 + shape_in_file: list[int] = field(default_factory=list) + dtype_size: int = 0 + dtype_str: str = "" + + +@dataclass +class _OutputFileData: + """ + Dataclass to store information about an output safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + fqn_data: Dictionary mapping tensor names to their metadata + """ + + metadata_size: int = 0 + fqn_data: dict[str, _FqnData] = field(default_factory=dict) + + +@dataclass +class _InputFileData: + """ + Dataclass to store information about an input safetensors file. + + Attributes: + metadata_size: Size of the metadata section in bytes + metadata: Json metadata from the safetensors file + """ + + metadata_size: int = 0 + metadata: Any = None + + +def _parse_input_metadata( + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Parse metadata from input safetensors files to determine the full tensor shapes and types. + + This function analyzes the metadata from all input files to determine the complete shape + of each tensor after consolidation. It updates the output_files_data with this information. + + Args: + input_files_data: dict of metadata from input safetensors files + output_files_data: Dictionary mapping output file paths to their metadata + + Raises: + ValueError: If no DCP custom metadata is found in a safetensors file + """ + + from safetensors.torch import _getdtype # type: ignore[import] + + # Dictionary to track the full size of each tensor across all shards + fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} + + for file_data in input_files_data.values(): + safetensors_metadata = file_data.metadata + dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata) + if not dcp_sharding_info: + raise ValueError( + "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated." + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + + # Get the shape of this tensor shard and its offset in the full tensor + sizes = val[SHAPE_KEY] + offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + + if key not in fqn_to_size_mapping: + # First time seeing this tensor - calculate its full size by adding offsets to dimensions + cur_size = [size + offset for size, offset in zip(sizes, offsets)] + fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY]) + else: + # We've seen this tensor before - update its size if this shard extends beyond current known dimensions + cur_size = fqn_to_size_mapping[key][0] + for i in range(len(sizes)): + cur_size[i] = max(cur_size[i], sizes[i] + offsets[i]) + + # Now that we know the full size of each tensor, populate the output file data + for fqn, tensor_info in fqn_to_size_mapping.items(): + tensor_size = tensor_info[0] + dtype_str = tensor_info[1] + for output_data in output_files_data.values(): + # Add this tensor to the output file if it's already assigned there + if fqn in output_data.fqn_data: + output_data.fqn_data[fqn] = _FqnData( + shape_in_file=tensor_size, + dtype_size=torch.finfo(_getdtype(dtype_str)).bits + // 8, # Convert bits to bytes + dtype_str=dtype_str, + ) + + +def _write_metadata( + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Write metadata to the beginning of each output safetensors file. + + This function writes the metadata section to each output file, including information + about tensor shapes, data types, and offsets. It also updates the offset_in_file + field for each tensor in the output_files_data. + + Args: + output_files_data: Dictionary mapping output file paths to their metadata + """ + # Process each output file + for file_path, output_data in output_files_data.items(): + with open(file_path, "wb") as f: + metadata = {} + curr_offset = 0 + + # Calculate offsets for each tensor in the file + for fqn, fqn_data in output_data.fqn_data.items(): + # Calculate the end offset by multiplying all dimensions and the data type size + end_offset = ( + curr_offset + + math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + ) + + # Store metadata for this tensor + metadata[fqn] = { + SHAPE_KEY: fqn_data.shape_in_file, + DTYPE_KEY: fqn_data.dtype_str, + DATA_OFFSETS_KEY: [ + curr_offset, + end_offset, + ], # Start and end byte offsets + } + # Store the offset for later use when writing the actual tensor data + fqn_data.offset_in_file = curr_offset + + # Update current offset for the next tensor + curr_offset = end_offset + + # Convert metadata to JSON and encode as bytes + json_metadata = json.dumps(metadata) + json_bytes = json_metadata.encode("utf-8") + + # Write the metadata size as an 8-byte unsigned integer (little-endian) + size_in_bytes = len(json_bytes) + header_len = struct.pack(" bytes: + """ + Read tensor data from a safetensors file using memory mapping for efficiency. + + Args: + file_path: Path to the safetensors file + start_offset: Start offset of tensor data within the data section + end_offset: End offset of tensor data within the data section + metadata_size: Size of the metadata header + + Returns: + Raw tensor data as bytes + """ + # Use mmap for efficient access + with open(file_path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + absolute_start = metadata_size + start_offset + absolute_end = metadata_size + end_offset + return bytes(mm[absolute_start:absolute_end]) + + +def _process_output_file( + output_file: str, + output_data: _OutputFileData, + input_files_data: dict[str, _InputFileData], +) -> None: + """ + Process a single output file by writing tensor data from input files using memory mapping. + + This function is designed to be run in parallel for different output files. + + Args: + output_file: Path to the output file + output_data: Metadata for the output file + input_files_data: Dictionary mapping input file paths to their metadata + """ + + sorted_tensors = sorted( + output_data.fqn_data.items(), key=lambda x: x[1].offset_in_file + ) + + with open(output_file, "r+b") as output_stream: + output_stream.seek(0, os.SEEK_END) + # Process each tensor in sequential output order + for tensor_fqn, tensor_fqn_data in sorted_tensors: + full_tensor_mv = memoryview( + bytearray( + math.prod(tensor_fqn_data.shape_in_file) + * tensor_fqn_data.dtype_size + ) + ) + + # Process each input safetensors file + for safetensors_file in input_files_data: + file_metadata = input_files_data[safetensors_file].metadata + input_metadata_size = input_files_data[safetensors_file].metadata_size + + if tensor_fqn not in file_metadata: + continue + + metadata = file_metadata[tensor_fqn] + + data_offsets = metadata[DATA_OFFSETS_KEY] + + # Use memory mapping to read tensor data efficiently + data_to_write = _read_tensor_data_mmap( + safetensors_file, + data_offsets[0], + data_offsets[1], + input_metadata_size, + ) + + # Get the offsets of this tensor shard within the full tensor + fqn_custom_metadata = _get_dcp_custom_metadata(file_metadata)[ + tensor_fqn + ] # type: ignore[index] + offsets_of_tensor_being_read = fqn_custom_metadata[SAVED_OFFSETS_KEY] # type: ignore[index] + + # Write this tensor shard to the appropriate position in the output file + _write_sub_tensor_to_file_optimized( + full_tensor_mv, + data_to_write, + tensor_fqn_data.dtype_size, # Size of each element in bytes + tensor_fqn_data.shape_in_file, # Full tensor shape + offsets_of_tensor_being_read, # Where this shard belongs in the full tensor + metadata[SHAPE_KEY], # Shape of this shard + ) + + output_stream.write(full_tensor_mv) + + +def _write_data( + input_files_data: dict[str, _InputFileData], + output_files_data: dict[str, _OutputFileData], + num_threads: int = 1, +) -> None: + """ + Write tensor data from input files to the output files using memory mapping. + + This function reads tensor data from each input file and writes it to the appropriate + position in the output files based on the tensor's offsets. When num_threads > 1, + the work is split across threads with each thread handling a different output file. + + Args: + input_files_data: Dictionary mapping input file paths to their metadata + output_files_data: Dictionary mapping output file paths to their metadata + num_threads: Number of threads to use for parallel processing + """ + if num_threads <= 1 or len(output_files_data) <= 1: + # Sequential processing + for output_file, output_data in output_files_data.items(): + _process_output_file(output_file, output_data, input_files_data) + else: + # Parallel processing with ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(num_threads, len(output_files_data)) + ) as executor: + futures = [] + for output_file, output_data in output_files_data.items(): + futures.append( + executor.submit( + _process_output_file, + output_file, + output_data, + input_files_data, + ) + ) + + # Wait for all futures to complete + for future in concurrent.futures.as_completed(futures): + # Handle any exceptions that might have occurred + try: + future.result() + except Exception as e: + print(f"Error processing output file: {e}") + raise + + +def _write_sub_tensor_to_file_optimized( + full_tensor_mv: memoryview, + sub_tensor_bytes: bytes, + element_size: int, + tensor_shape: list[int], + sub_tensor_offsets: list[int], + sub_tensor_shape: list[int], +) -> None: + """ + Optimized version that writes the maximum number of contiguous bytes possible. + + Uses a unified algorithm that calculates the maximum contiguous bytes that can be + written in each iteration and continues until the entire subtensor is written. + Handles all sharding patterns efficiently: + - Full sub-tensor at once for row-wise sharding + - Row-by-row for column-wise sharding + - Optimized chunks for other patterns + + Args: + full_tensor_mv: Buffer to write the full tensor to + sub_tensor_bytes: Raw tensor data as bytes + element_size: Size of each element in bytes + tensor_shape: Shape of the full tensor + sub_tensor_offsets: Starting offsets of the sub-tensor within the full tensor + sub_tensor_shape: Shape of the sub-tensor + """ + # Handle empty tensors + if not tensor_shape or not sub_tensor_shape: + return + + # Calculate tensor strides for efficient indexing + tensor_strides = [1] + for i in range(len(tensor_shape) - 1, 0, -1): + tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) + + sub_tensor_strides = [1] + for i in range(len(sub_tensor_shape) - 1, 0, -1): + sub_tensor_strides.insert(0, sub_tensor_strides[0] * sub_tensor_shape[i]) + + total_elements = math.prod(sub_tensor_shape) + + elements_written = 0 + while elements_written < total_elements: + # Convert linear index to multi-dimensional indices + temp_idx = elements_written + indices = [] + for dim_size in reversed(sub_tensor_shape): + indices.append(temp_idx % dim_size) + temp_idx //= dim_size + indices.reverse() + + # Calculate maximum contiguous elements we can write from this position + max_contiguous = _calculate_max_contiguous_elements( + indices, sub_tensor_shape, tensor_shape + ) + + # Calculate source position in bytes + src_pos = sum(idx * stride for idx, stride in zip(indices, sub_tensor_strides)) + src_byte_offset = src_pos * element_size + + # Calculate destination position in bytes + dest_indices = [ + idx + offset for idx, offset in zip(indices, sub_tensor_offsets) + ] + dest_pos = sum( + idx * stride for idx, stride in zip(dest_indices, tensor_strides) + ) + dest_byte_offset = dest_pos * element_size + + # Write the contiguous chunk + bytes_to_write = max_contiguous * element_size + chunk_data = sub_tensor_bytes[ + src_byte_offset : src_byte_offset + bytes_to_write + ] + full_tensor_mv[dest_byte_offset : dest_byte_offset + bytes_to_write] = ( + chunk_data + ) + + elements_written += max_contiguous + + +def _calculate_max_contiguous_elements( + indices: list[int], + sub_tensor_shape: list[int], + tensor_shape: list[int], +) -> int: + """ + Calculate the maximum number of contiguous elements that can be written from current position. + + This determines the largest chunk by checking how elements are laid out in memory + and finding natural boundaries where contiguity breaks. + + Args: + indices: Current position indices in the sub-tensor + sub_tensor_shape: Shape of the sub-tensor being written + tensor_shape: Shape of the full tensor + + Raises: + ValueError: If input lists are empty, have mismatched lengths, or contain invalid values + """ + # Validate input lists are not empty + if not indices or not sub_tensor_shape or not tensor_shape: + raise ValueError("Input lists cannot be empty") + + # Validate all lists have the same length (same number of dimensions) + if not (len(indices) == len(sub_tensor_shape) == len(tensor_shape)): + raise ValueError( + f"All input lists must have the same length. Got indices: {len(indices)}, " + f"sub_tensor_shape: {len(sub_tensor_shape)}, tensor_shape: {len(tensor_shape)}" + ) + + # Validate indices are within bounds of sub_tensor_shape + for i, (idx, sub_dim) in enumerate(zip(indices, sub_tensor_shape)): + if idx >= sub_dim: + raise ValueError( + f"Index {idx} at dimension {i} is out of bounds for sub-tensor shape {sub_tensor_shape}" + ) + + # Validate sub_tensor dimensions don't exceed tensor dimensions + for i, (sub_dim, tensor_dim) in enumerate(zip(sub_tensor_shape, tensor_shape)): + if sub_dim > tensor_dim: + raise ValueError( + f"Sub-tensor dimension {sub_dim} at position {i} exceeds tensor dimension {tensor_dim}" + ) + + # Start with elements remaining in the last dimension + max_contiguous = sub_tensor_shape[-1] - indices[-1] + + # Check if we can extend across multiple dimensions + # We can write across dimension boundaries if we're writing complete "rows" + # and the layout in destination tensor maintains contiguity + + # For 2D case: check if we can write multiple complete rows + if len(sub_tensor_shape) >= 2: + # If we're at the start of a row and can write complete rows + if indices[-1] == 0: # At start of last dimension (column) + rows_remaining = sub_tensor_shape[-2] - indices[-2] # Rows left to write + + # Check if writing complete rows maintains contiguity in destination + # This is true for row-wise sharding or when sub-tensor spans full width + if sub_tensor_shape[-1] == tensor_shape[-1]: # Full width + max_contiguous = rows_remaining * sub_tensor_shape[-1] + + # For higher dimensions, check if we can extend further + if len(sub_tensor_shape) >= 3 and indices[-2] == 0: + # Check if we can write complete 2D slices + remaining_in_dim = sub_tensor_shape[-3] - indices[-3] + if ( + sub_tensor_shape[-1] == tensor_shape[-1] + and sub_tensor_shape[-2] == tensor_shape[-2] + ): + max_contiguous = ( + remaining_in_dim * sub_tensor_shape[-2] * sub_tensor_shape[-1] + ) + + return max_contiguous + + +def _write_overall_metadata_file( + output_dir: str, + output_files_data: dict[str, _OutputFileData], +) -> None: + """ + Write the overall metadata file that maps tensor names to their file locations. + + This creates a model.safetensors.index.json file that HuggingFace models use + to locate tensors across multiple files. + + Args: + output_dir: Directory where the metadata file will be written + output_files_data: Dictionary mapping output file paths to their metadata + """ + total_size = 0 + weight_map = {} + for output_path, value in output_files_data.items(): + for fqn, fqn_data in value.fqn_data.items(): + total_size += math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size + weight_map[fqn] = os.path.basename(output_path) + + metadata_to_write: dict[str, Any] = {} + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = weight_map + + metadata_path = os.path.join(output_dir, f"{_metadata_fn}") + with open(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + +def _consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_file_mapping: dict[str, str], + num_threads: int, +) -> dict[str, _OutputFileData]: + output_files_data: dict[str, _OutputFileData] = {} + # Create multiple output files based on the provided mapping + for fqn, filename in fqn_to_file_mapping.items(): + output_path = os.path.join(output_dir, filename) + + if output_path not in output_files_data: + output_files_data[output_path] = _OutputFileData(fqn_data={fqn: _FqnData()}) + else: + output_files_data[output_path].fqn_data[fqn] = _FqnData() + + # Find all safetensors files in the input directory + safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}")) + + # Read metadata from all input files + input_files_data: dict[str, _InputFileData] = {} + for safetensor_file in safetensors_files: + with open(safetensor_file, "rb") as f: + metadata, size = _get_safetensors_file_metadata(f) + input_files_data[safetensor_file] = _InputFileData( + metadata_size=size, metadata=metadata + ) + # Step 1: Parse metadata to determine tensor shapes and types + _parse_input_metadata(input_files_data, output_files_data) + + # Step 2: Write metadata headers to output files + _write_metadata(output_files_data) + # Step 3: Write actual tensor data from input files to output files + _write_data(input_files_data, output_files_data, num_threads) + + return output_files_data + + +def consolidate_safetensors_files( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: dict[str, int], + num_threads: int = 1, +) -> None: + """ + Main function to consolidate sharded safetensors files into one or more output files. + + This function orchestrates the entire consolidation process: + 1. Sets up the output file structure based on the fqn_to_index_mapping + 2. Finds all safetensors files in the input directory + 3. Parses metadata from all input files + 4. Writes metadata to the output files + 5. Writes tensor data from input files to output files + 6. Writes overall model.index.safetensors.json file with weight map + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Optional mapping of tensor names to output file indices. + If None, all tensors will be consolidated into a single file. + num_threads: Number of threads to use for parallel processing of saving data to output files. + """ + start_time = time.time() + logger.info( + "Consolidating safetensors files from %s to %s. Beginning at time %f", + input_dir, + output_dir, + start_time, + ) + + max_index = max(fqn_to_index_mapping.values()) + fqn_to_file_mapping = { + fqn: _gen_file_name(idx, max_index) for fqn, idx in fqn_to_index_mapping.items() + } + + output_files_data = _consolidate_safetensors_files( + input_dir, output_dir, fqn_to_file_mapping, num_threads + ) + + # Step 4: Write overall model.index.safetensors.json file with weight map + _write_overall_metadata_file(output_dir, output_files_data) + + logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time) + + +def consolidate_safetensors_files_on_every_rank( + input_dir: str, + output_dir: str, + fqn_to_index_mapping: dict[str, int], + num_threads: int = 1, + process_group: Optional[dist.ProcessGroup] = None, +) -> None: + """ + Consolidate sharded safetensors files across multiple ranks, with each rank handling a subset of output files. + + This function distributes the consolidation work by assigning output files to different ranks. + All tensors with the same index in fqn_to_index_mapping are processed by the same rank, + as they belong to the same output file. + + If process_group is provided, rank and world_size will be derived from it. Otherwise, + they will be automatically detected from the distributed environment if available. + + Args: + input_dir: Directory containing sharded safetensors files + output_dir: Directory where consolidated files will be written + fqn_to_index_mapping: Mapping of tensor names to output file indices + num_threads: Number of threads to use for parallel processing on each rank + process_group: PyTorch distributed process group (default: None, will use default group) + """ + + start_time = time.time() + # Derive rank and world_size from process_group or default distributed environment + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + else: + # Default to single process mode if distributed is not initialized + rank = 0 + world_size = 1 + logger.warning( + "Distributed environment not initialized. Running in single process mode." + ) + logger.info( + "Rank %d/%d: Consolidating safetensors files from %s to %s", + rank, + world_size, + input_dir, + output_dir, + ) + + # Find all unique indices in the mapping + unique_indices = set(fqn_to_index_mapping.values()) + + # Distribute indices across ranks + indices_for_this_rank = [] + for idx in unique_indices: + # Simple distribution: index % world_size == rank + if idx % world_size == rank: + indices_for_this_rank.append(idx) + + logger.info( + "Rank %d: Assigned %d output files out of %d total files", + rank, + len(indices_for_this_rank), + len(unique_indices), + ) + + # Filter the fqn_to_index_mapping to only include tensors for this rank + filtered_mapping = { + fqn: idx + for fqn, idx in fqn_to_index_mapping.items() + if idx in indices_for_this_rank + } + + if filtered_mapping: + # Convert index mapping to filename mapping + max_index = max(unique_indices) + filtered_filename_mapping = {} + for fqn, idx in filtered_mapping.items(): + filename = _gen_file_name(idx, max_index) + filtered_filename_mapping[fqn] = filename + + # Call the existing consolidation function with the filtered mapping + _consolidate_safetensors_files( + input_dir=input_dir, + output_dir=output_dir, + fqn_to_file_mapping=filtered_filename_mapping, + num_threads=num_threads, + ) + + logger.info( + "Rank %d: Done consolidating. Processed %d unique indices in %.2f secs.", + rank, + len(indices_for_this_rank), + time.time() - start_time, + ) + + # Wait for all ranks to complete + if dist.is_available() and dist.is_initialized(): + logger.info("Rank %d: Waiting for all ranks to complete...", rank) + dist.barrier() + logger.info("Rank %d: All ranks have completed.", rank) + if rank == 0: + logger.info("Total time taken: %.2f secs.", time.time() - start_time) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py new file mode 100644 index 0000000000000000000000000000000000000000..acb81c41862852320cdc1d412ddaffdd48e73841 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +from collections import defaultdict +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan, WriteItem + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_save_plans"] + + +def dedup_save_plans( + all_plans: list[SavePlan], + save_to_lowest_rank: bool = False, +) -> list[SavePlan]: + """ + Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across + a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. + + Please note that this function does not modify the original SavePlans, but rather returns + """ + + # Map to query the plan indices that a write item is duplicated in + write_item_to_plan_indices: dict[MetadataIndex, set[int]] = defaultdict(set) + # Map to query the write item from its index + write_item_idx_to_write_item: dict[MetadataIndex, WriteItem] = {} + # Set of write item indices that are present in each plan + # After deduplication, this will be the set of write item indices that are present in the final plans + plan_to_item_indices: list[set[MetadataIndex]] = [ + {item.index for item in plan.items} for plan in all_plans + ] + + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + # map each write item to its plan + write_item_to_plan_indices[write_item.index].add(plan_idx) + write_item_idx_to_write_item[write_item.index] = write_item + plan_to_size = [0] * len(all_plans) + for write_item_idx, plan_indices in write_item_to_plan_indices.items(): + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) + + write_item = write_item_idx_to_write_item[write_item_idx] + # Ignore the storage size of anything that is not a tensor, since + # we don't know how much storage they represent + plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1 + for plan_idx in plan_indices - {select_plan_idx}: + plan_to_item_indices[plan_idx].discard(write_item_idx) + # Sanity check + if len(all_plans) != len(plan_to_item_indices): + raise AssertionError("len(all_plans) != len(plan_to_item_indices)") + # Create new plans with the updated write items post deduplication + return [ + dataclasses.replace( + plan, items=[item for item in plan.items if item.index in item_indexes] + ) + for plan, item_indexes in zip(all_plans, plan_to_item_indices) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c57b2e149106abbac66522aa571d1a462db4157d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_dedup_tensors.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import dataclasses +import logging +from typing import TYPE_CHECKING + +from torch.distributed.checkpoint.planner import SavePlan + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import MetadataIndex + +__all__ = ["dedup_tensors"] + + +def init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + level = logging.INFO + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + console.setFormatter(formatter) + console.setLevel(level) + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = init_logger() + + +# TODO add docstring for dedup_tensors +def dedup_tensors(all_plans: list[SavePlan]) -> list[SavePlan]: + all_plans = list(all_plans) + key_to_plan: dict[MetadataIndex, list[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove duplicates by always keeping the first entry. + # Compute the per-rank remove set. + plan_to_keys: dict[int, list[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + if len(plan_to_keys) > 0: + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [ + write_item + for write_item in all_plans[plan_idx].items + if write_item.index not in key_set + ] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d28f25bab8bb9e1a5246bfe7f22940ae49636a10 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c3cfeccf21bffc138620ff93ea360738cfd6763 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/barriers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af7154d93cd57629cfdac4732fefcafdacf67ca8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/builder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa3c8874bf1ed2e6bfa0517bf031edd5ba7c6ee Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_process.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c07751c14105ff905371eacb43be795d5840625f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_reader.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70596a675918e54d9b74c283c851d5befe0d7199 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpoint_writer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae17cdba05a0cdae265dfc8d20ed0e7b7b5943b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/checkpointer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2584512173d356c1f5033c5fa3ddab2d9127828 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ad294e985fa847a5d35fd34b8869b36b4079ec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/staging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9fd9053bfa8b95f04d322e0cb30be989a7fa65 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a8110dc76f286fd89bf3df39243186121bfd232 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_experimental/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..663caa8a857263e3fc2924e3e1ec80d13a9ae6b0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_extension.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import abc +import io +from collections.abc import Sequence +from typing import cast, IO, Optional + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +from torch._utils import try_import + + +# NOTE: everything in this file is experimental, and subject to +# change. Feedback and bug fixes are always welcome. + +pyzstd_module_name = "pyzstd" +pyzstd = try_import(pyzstd_module_name) +zstandard_module_name = "zstandard" +zstandard = try_import(zstandard_module_name) + + +__all__ = [ + "Extension", + "StreamTransformExtension", + "ZStandard", + "ExtensionRegistry", +] + + +class Extension(abc.ABC): + """ + Extensions provide modular additions to functionality within distributed checkpointing, + which affect the layout or format of the written artifacts. Extensions may be + built into pytorch, or provided externally. + + When writing, the caller provides a list of extension instances of the appropriate + type. Each extension can output a descriptor which is used to reconstitute the + extension at read-time. + """ + + @staticmethod + @abc.abstractmethod + def registry_name() -> str: + """ + See ExtensionRegistry.from_descriptor_list + """ + + @staticmethod + @abc.abstractmethod + def from_descriptor(version: str) -> "Extension": + """ + See ExtensionRegistry.from_descriptor_list + """ + + @abc.abstractmethod + def get_descriptor(self) -> str: + """ + Return descriptor name to be included in metadata. The form should be + "extension_name[@local-domain][/version]". + """ + + +class StreamTransformExtension(Extension): + """ + An extension which performs transformation on a byte stream, such as compression + or encryption. + + Implementations should try to be memory friendly and performant. For example, don't + read the whole input, then transform it, and write it back. If at all possible, do it in + chunks. But, don't read/transform/write one byte at a time, either. + """ + + @abc.abstractmethod + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + """ + Takes a writeable output stream, and generates a new stream which implements the + output transform. Input data written to the returned stream will be transformed + and written to the `output` argument stream. + """ + + @abc.abstractmethod + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + """ + Takes a readable input stream, and generates a new stream which implements the + input transform. When the returned stream is read, data will be read from the + 'input' stream, transformed, and returned. + """ + + +class ZStandard(StreamTransformExtension): + @staticmethod + def is_available() -> bool: + return zstandard is not None or pyzstd is not None + + @staticmethod + # pyrefly: ignore [bad-override] + def from_descriptor(version: str) -> "ZStandard": + if version.partition(".")[0] != "1": + raise ValueError(f"Unknown extension {version=}") + if not ZStandard.is_available(): + raise ValueError( + f"Stream with ZStandard compression cannot be processed because " + f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + return ZStandard() + + @staticmethod + def registry_name() -> str: + return "stream.zstd" + + def __init__(self) -> None: + super().__init__() + if not ZStandard.is_available(): + raise ValueError( + f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'" + ) + + def get_descriptor(self) -> str: + return f"{self.registry_name()}/1" + + def transform_to(self, output: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + compressor = zstandard.ZstdCompressor() # type: ignore[union-attr] + return compressor.stream_writer(output) + + class Writer(io.RawIOBase): + def __init__(self, output: IO[bytes]) -> None: + self.output = output + self.compressor = pyzstd.ZstdCompressor() # type: ignore[union-attr] + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> Optional[int]: + outdata = self.compressor.compress(b) + if outdata: + self.output.write(outdata) + return len(memoryview(b)) + + def flush(self) -> None: + outdata = self.compressor.flush() + if outdata: + self.output.write(outdata) + self.output.flush() + + return cast(IO[bytes], Writer(output)) + + def transform_from(self, input: IO[bytes]) -> IO[bytes]: + if zstandard is not None: + decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr] + return decompressor.stream_reader(input) + + class Reader(io.RawIOBase): + def __init__(self, input: IO[bytes]) -> None: + self.input = input + self.decompressor = pyzstd.EndlessZstdDecompressor() # type: ignore[union-attr] + + def readable(self) -> bool: + return True + + def readinto(self, b: Buffer) -> Optional[int]: + # This needs to read enough so it can decompress + # something so the output doesn't look like EOF. This + # means reading at least one block. The max block + # size is 128KB, so we read that plus some + # overhead to be sure. + + if self.decompressor.needs_input: + indata = self.input.read((128 + 6) * 1024) + else: + indata = b"" + + bview = memoryview(b) + blen = len(bview) + outdata = self.decompressor.decompress(indata, blen) + if outdata is None: + return None + + count = len(outdata) + bview[:count] = outdata + return count + + def seekable(self) -> bool: + return False + + return cast(IO[bytes], Reader(input)) + + +class ExtensionRegistry: + def __init__(self) -> None: + # Populate default registry contents + self.extensions: dict[str, type[Extension]] = { + cls.registry_name(): cls for cls in (ZStandard,) + } + + def register(self, cls: type[Extension]) -> None: + self.extensions[cls.registry_name()] = cls + + def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]: + """ + Given a seuquence of descriptor strings as returned by + Extension.get_descriptor at save time, creates a sequence of + Extension instances. The name[@local-domain] preceding the + version number is used to look up an implementation class in + the registry, and the version is passed to the class's + from_descriptor static method. If the registry contains no + match, this will throw ValueError. If the from_descriptor + method raises an exception, that will pass through to the + caller. + """ + + def from_descriptor(desc: str) -> Extension: + name, _, version = desc.partition("/") + if version is None: + version = 0 + ext = self.extensions.get(name) + if not ext: + raise ValueError(f"Unknown extension {name=}") + # pyrefly: ignore [bad-argument-type] + return ext.from_descriptor(version) + + return [from_descriptor(desc) for desc in descriptors] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..e239bbe891fb95b374479fdafaab9a0d16604147 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -0,0 +1,168 @@ +# Mypy will not try inferring the types of any 3rd party libraries installed. +# mypy: ignore-errors + +import io +import os +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from pathlib import Path +from typing import Optional, TYPE_CHECKING, Union + +from fsspec.core import url_to_fs + +from torch.distributed.checkpoint._extension import StreamTransformExtension +from torch.distributed.checkpoint.filesystem import ( + FileSystemBase, + FileSystemReader, + FileSystemWriter, + SerializationFormat, +) + + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + +__all__ = [ + "FsspecWriter", + "FsspecReader", +] + + +class FileSystem(FileSystemBase): + def __init__(self) -> None: + self.fs: Optional[AbstractFileSystem] = None + + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + if self.fs is None: + raise AssertionError("fs should not be None") + path = os.fspath(path) + + # fsspec does not support concurrent transactions, and not all + # AbstractFileSystem have working rollback implementations, so + # just manually delete the file if necessary on errors. + with self.fs.open(path, mode) as stream: + try: + yield stream + except: # noqa: B001,E722 + if any(ch in mode for ch in "w+a"): # cleanup file if not read-only + try: + self.rm_file(path) + except: # noqa: B001,E722 + pass + raise + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + return os.path.join(path, suffix) + + def init_path( + self, path: Union[str, os.PathLike], **kwargs + ) -> Union[str, os.PathLike]: + self.fs, _ = url_to_fs(path, **kwargs) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + self.fs.rename(path, new_path) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + self.fs.makedirs(path, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return False + + try: + url_to_fs(checkpoint_id) + except ValueError: + return False + + return True + + def exists(self, path: Union[str, os.PathLike]) -> bool: + return self.fs.exists(path) + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + self.fs.rm(path) + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + # setting detail to False explicitly to keep the list[str] return type, + # instead of the list[Dict] return type when detail=True + return self.fs.ls(path, detail=False) + + +# TODO: add the dcp.async_save mixin +class FsspecWriter(FileSystemWriter): + """ + Basic implementation of StorageWriter using FFspec. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + **kwargs, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__( + path, + single_file_per_rank, + sync_files, + thread_count, + per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FsspecReader(FileSystemReader): + def __init__(self, path: Union[str, os.PathLike], **kwargs) -> None: + super().__init__(path) + self.fs = FileSystem() + self.path = self.fs.init_path(path, **kwargs) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d14229b7f8ccfe5a51211d0fb6a4c332af6066b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_hf_utils.py @@ -0,0 +1,106 @@ +import io +import json +import struct +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +_metadata_fn: str = "model.safetensors.index.json" + +FILE_NAME = "model-{cpt_idx}-of-{num_files}" +SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}" +SUFFIX = ".safetensors" + +# metadata keys +CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO" +DEFAULT_EXTRA_METADATA_KEY = "__metadata__" +SAVED_OFFSETS_KEY = "saved_offsets" +SHAPE_KEY = "shape" +DATA_KEY = "data" +DTYPE_KEY = "dtype" +DATA_OFFSETS_KEY = "data_offsets" + +DTYPE_MAP = { + "F16": torch.float16, + "F32": torch.float32, + "F64": torch.float64, + "I8": torch.int8, + "U8": torch.uint8, + "I16": torch.int16, + "I32": torch.int32, + "I64": torch.int64, + "BF16": torch.bfloat16, +} + +HF_DCP_VERSION: float = 1.0 +DCP_VERSION_KEY = "DCP_VERSION" +DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO" + +FORMAT_KEY = "format" +FORMAT_VALUE = "pt" + +NUM_BYTES_FOR_HEADER_LEN = 8 + +SHARDED_DIR_NAME = "sharded" + + +@dataclass +class _HFStorageInfo: + """This is the per entry storage info.""" + + relative_path: str + shape: torch.Size + dtype: torch.dtype + + +def _gen_file_name( + index: int, largest_index: int, shard_index: Optional[int] = None +) -> str: + if shard_index is not None: + return ( + SHARDED_FILE_NAME.format( + shard_idx=f"{shard_index}".zfill(5), + cpt_idx=f"{index}".zfill(5), + num_files=f"{largest_index}".zfill(5), + ) + + SUFFIX + ) + else: + return ( + FILE_NAME.format( + cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) + ) + + SUFFIX + ) + + +def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]: + # this uses the same logic that's done in HF code base + # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 + # and follows their documentation on how their files are serialized + # https://huggingface.co/docs/safetensors/index#format + + header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN) + header_len = struct.unpack(" torch.dtype: + try: + dtype = DTYPE_MAP[dtype_str] + except KeyError: + dtype = torch.get_default_dtype() + + return dtype + + +def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]: + if DEFAULT_EXTRA_METADATA_KEY in metadata: + custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY] + if CUSTOM_METADATA_KEY in custom_metadata: + return json.loads(custom_metadata[CUSTOM_METADATA_KEY]) + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..eb26058370f766fbb96e4a5f1530577234eed62a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_nested_dict.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) + + +""" +TODO: +Need to add ability to handle tuple, OrderedDict, NamedTuple. +Update mappings from dict to a class. +Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple. +""" + + +FLATTEN_MAPPING = dict[str, OBJ_PATH] + + +# TODO: Update Docstring for nested_dict.py +def flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. + """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + set_element(nested, mapping[key], value) + return nested diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py new file mode 100644 index 0000000000000000000000000000000000000000..b258517bdcebaa553c4acf7f5511b29432304ed9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_pg_transport.py @@ -0,0 +1,387 @@ +import logging +import pickle +import time +from collections.abc import Callable, Generator +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import cast, Optional, TypeVar, Union + +import torch +from torch.distributed import ProcessGroup, Work +from torch.distributed._shard.sharded_tensor import ( + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) +from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata +from torch.distributed.tensor import _DTensorSpec, DTensor +from torch.utils._pytree import ( + KeyPath, + tree_flatten_with_path, + tree_unflatten, + TreeSpec, +) + + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class _TensorMeta: + """ + This is the metadata for a tensor that is used to transfer checkpoints. + It contains the shape, the dtype, the storage offset and the stride of the + tensor. + + This must be pickleable so that it can be sent over the wire. + """ + + shape: torch.Size + dtype: torch.dtype + storage_offset: int + stride: tuple[int, ...] + nbytes: int + + +@dataclass +class _DTensorMeta: + """ + This is the metadata for a DTensor that is used to transfer checkpoints. + It contains the metadata for the local tensor and the spec of the DTensor. + + This must be pickleable so that it can be sent over the wire. + """ + + local: _TensorMeta + spec: _DTensorSpec + + +@dataclass +class _ShardedTensorMeta: + """ + This is the metadata for a ShardedTensor that is used to transfer checkpoints. + It contains the metadata for all local shards and the global tensor metadata. + + This must be pickleable so that it can be sent over the wire. + """ + + local_shards_meta: list[_TensorMeta] + local_shards_shard_metadata: list[ + ShardMetadata + ] # Original shard metadata for each local shard + sharded_tensor_metadata: ShardedTensorMetadata + + +@dataclass +class _StateDictMeta: + """ + This is the metadata for a state dict that is used to transfer checkpoints. + It contains the step, the pytree spec of the state dict and the metadata for + each tensor in the state dict. + + This must be pickleable so that it can be sent over the wire. + + Args: + step: the step of the checkpoint to verify consistency + treespec: the pytree spec of the state dict + paths: the path of each leaf in the state dict + non_tensor_leaves: the metadata for each tensor in the state dict and any + non-tensor leaves in the state dict + """ + + treespec: TreeSpec + paths: list[KeyPath] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] + + +@contextmanager +def _timeit(name: str) -> Generator[None, None, None]: + start = time.perf_counter() + yield + dur = time.perf_counter() - start + logger.info("%s took %ss", name, dur) + + +def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]: + return ( + _cast_tensor(tensor, torch.uint8), + _TensorMeta( + shape=tensor.shape, + dtype=tensor.dtype, + storage_offset=cast(int, tensor.storage_offset()), + stride=tensor.stride(), + nbytes=tensor.untyped_storage().nbytes(), + ), + ) + + +def _prepare_state_dict( + state_dict: object, + device: torch.device, +) -> tuple[_StateDictMeta, list[torch.Tensor]]: + leaves: list[tuple[KeyPath, object]] + leaves, treespec = tree_flatten_with_path(state_dict) + + paths: list[KeyPath] = [] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] = [] + tensors: list[torch.Tensor] = [] + for key_path, v in leaves: + paths.append(key_path) + + if isinstance(v, DTensor): + tensor, tensor_meta = _prepare_tensor(v._local_tensor) + + tensors.append(tensor) + + non_tensor_leaves.append( + _DTensorMeta( + local=tensor_meta, + spec=v._spec, + ) + ) + elif isinstance(v, ShardedTensor): + # Handle ShardedTensor by extracting all local shards + local_shards = v.local_shards() + + # Prepare metadata for all local shards + local_shards_meta = [] + local_shards_shard_metadata = [] + for shard in local_shards: + tensor, tensor_meta = _prepare_tensor(shard.tensor) + tensors.append(tensor) + local_shards_meta.append(tensor_meta) + local_shards_shard_metadata.append(shard.metadata) + + non_tensor_leaves.append( + _ShardedTensorMeta( + local_shards_meta=local_shards_meta, + local_shards_shard_metadata=local_shards_shard_metadata, + sharded_tensor_metadata=v.metadata(), # Complete metadata + ) + ) + elif isinstance(v, torch.Tensor): + tensor, tensor_meta = _prepare_tensor(v) + tensors.append(tensor) + non_tensor_leaves.append(tensor_meta) + else: + non_tensor_leaves.append(v) + + return ( + _StateDictMeta( + treespec=treespec, + paths=paths, + non_tensor_leaves=non_tensor_leaves, + ), + tensors, + ) + + +def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """ + Casts the underlying storage to a tensor of the given dtype. + + The returned tensor will be of size ``storage.nbytes``. + + This works for all datatypes and supports strided/offset tensors with the + caveat that the cast tensor may be larger than the original tensor due to + the differences in striding. + """ + if type(tensor) is not torch.Tensor: + raise AssertionError(f"can only cast standard tensors not {type(tensor)}") + storage = tensor.untyped_storage() + ret = torch.tensor(storage, dtype=dtype, device=tensor.device) + if ret.untyped_storage() is not storage: + raise AssertionError("storage should be the same") + return ret + + +class PGTransport: + """ + This is a checkpoint transport that uses the process group to transfer checkpoints. + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + + Args: + pg: the process group to use for communication + timeout: the timeout for communication + device: the device to use for tensors + state_dict: if specified this function will be called to do an inplace + receive into the returned state_dict. This is much faster than + having to allocate new tensors and transferring them to the CPU. + """ + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta, + device: torch.device, + state_dict: Optional[Callable[[], object]] = None, + ) -> None: + self._work: list[Work] = [] + self._pg = pg + self._timeout = timeout + # pyrefly: ignore [read-only] + self._device = device + self._state_dict = state_dict + + def send_checkpoint(self, dst_ranks: list[int], state_dict: object) -> None: + """ + Send a checkpoint to multiple destination ranks. + + The process: + 1. Prepares the state dict by converting tensors to a serializable format + 2. Sends metadata as pickled data + 3. Sends each tensor sequentially to all destination ranks + + Args: + dst_ranks: List of destination ranks to send the checkpoint to + state_dict: The state dictionary containing model parameters + """ + with _timeit("preparing state_dict"): + meta, tensors = _prepare_state_dict(state_dict, device=self._device) + + work = [] + + with _timeit("send meta"): + buf = pickle.dumps(meta) + len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device) + buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([len_t], dst_rank, tag=1)) + work.append(self._pg.send([buf_t], dst_rank, tag=2)) + + with _timeit("send tensors"): + for i, t in enumerate(tensors): + original_device = t.device + t = t.to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([t], dst_rank, tag=3 + i)) + + # if we did a copy we should wait for the work to complete so we + # can free the memory to avoid OOMs + if original_device == torch.device("cpu"): + for w in work: + w.wait() + work = [] + + for w in work: + w.wait() + + def recv_checkpoint(self, src_rank: int) -> object: + """ + Receive a checkpoint from a source rank. + + The process: + 1. Receives metadata about the checkpoint structure + 2. Receives each tensor, potentially reusing existing tensors for in-place updates + 3. Reconstructs the original state dict structure + + Args: + src_rank: The source rank to receive the checkpoint from + + Returns: + The reconstructed state dictionary with model parameters + """ + state_dict = self._state_dict() if self._state_dict else {} + state_dict_leaves, _ = tree_flatten_with_path(state_dict) + + dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves) + + len_t = torch.zeros(1, dtype=torch.int64, device=self._device) + self._pg.recv([len_t], src_rank, tag=1).wait() + length = cast(int, len_t.item()) + + buf = torch.empty(length, dtype=torch.uint8, device=self._device) + self._pg.recv([buf], src_rank, tag=2).wait() + + meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes()) + + i: int = 0 + works: list[Work] = [] + + def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: + nonlocal i + + inplace = dst_tensors.get(path) + if ( + isinstance(inplace, torch.Tensor) + and inplace.device.type == self._device.type + ): + if isinstance(inplace, DTensor): + inplace = inplace._local_tensor + t = _cast_tensor(inplace, torch.uint8) + if t.nbytes != v.nbytes: + raise AssertionError("inplace tensor storage must be the same size") + else: + t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) + + work = self._pg.recv([t], src_rank, tag=3 + i) + i += 1 + + if inplace is None: + # if not inplace we need to copy it to CPU to avoid OOMing + work.wait() + t = t.cpu() + else: + works.append(work) + + return torch.as_strided( + t.view(v.dtype), + size=v.shape, + stride=v.stride, + storage_offset=v.storage_offset, + ) + + values: list[object] = [] + for path, v in zip(meta.paths, meta.non_tensor_leaves): + if isinstance(v, _TensorMeta): + values.append(recv(path, v)) + elif isinstance(v, _DTensorMeta): + tensor = recv(path, v.local) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] + values.append(DTensor(tensor, v.spec, requires_grad=False)) + elif isinstance(v, _ShardedTensorMeta): + # Receive all local shards that were sent to us + local_shards = [] + current_rank = self._pg.rank() + + # Receive tensors for each local shard that was sent + for j, shard_meta in enumerate(v.local_shards_meta): + tensor = recv(path, shard_meta) + + # Use the original shard metadata that was stored during preparation + # but update the placement to reflect the current rank/device + original_shard_metadata = v.local_shards_shard_metadata[j] + updated_shard_metadata = ShardMetadata( + shard_offsets=original_shard_metadata.shard_offsets, + shard_sizes=original_shard_metadata.shard_sizes, + placement=f"rank:{current_rank}/{tensor.device.type}", + ) + + local_shard = ShardedTensorShard( + tensor=tensor, metadata=updated_shard_metadata + ) + local_shards.append(local_shard) + + # Use complete metadata to reconstruct ShardedTensor + sharded_tensor = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=v.sharded_tensor_metadata, + ) + ) + values.append(sharded_tensor) + else: + values.append(v) + + for work in works: + work.wait() + + return tree_unflatten(values, meta.treespec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a68bcddeb7f9d9ffe6f89056dfe1ccc30cc12eb5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +from typing import TYPE_CHECKING + +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.remote_device import _remote_device + +from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from .utils import _element_wise_add, _normalize_device_info + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata + + +# TODO: We need to refactor this code. +def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + r""" + Transform ``state_dict`` by flattening all nested ShardedTensor instances found. + + The resulting ShardedTensor instances are only correct regarding the local shard and + MUST not be used for any other purpose but checkpointing, as no operator will work with them. + + This function should be used in conjunction with a state_dict produced by FSDP's + StateDictType.SHARDED_STATE_DICT methods. + """ + new_state_dict: STATE_DICT_TYPE = {} + + def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if not isinstance(value, ShardedTensor): + set_element(new_state_dict, path, value) + return + shards = value.local_shards() + + if len(shards) == 0: + return + if len(shards) != 1: + set_element(new_state_dict, path, value) + return + + outer_shard = shards[0] + + inner_st = outer_shard.tensor + if not isinstance(inner_st, ShardedTensor): + set_element(new_state_dict, path, value) + return + + if len(inner_st.local_shards()) != 1: + raise ValueError("Cannot handle inner tensor with more than 1 shard") + inner_shard = inner_st.local_shards()[0] + + local_shards = [ + Shard( + tensor=inner_shard.tensor, + metadata=ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_shard.metadata.shard_offsets, + ), + shard_sizes=inner_shard.metadata.shard_sizes, + placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}", + ), + ) + ] + + st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata()) + other_rank = 0 if dist.get_rank() > 0 else 1 + device_info = _normalize_device_info(inner_shard.tensor.device.type, 0) + + # Remove the outer ST shard the inner ST covers + for i, shard_md in enumerate(st_meta.shards_metadata): + if shard_md.shard_offsets == outer_shard.metadata.shard_offsets: + st_meta.shards_metadata.pop(i) + break + + # Attribute other rank for the other shards + for shard_md in st_meta.shards_metadata: + shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}") + + # Add other inner shards from the inner tensor + for inner_md in inner_st.metadata().shards_metadata: + if inner_md.shard_offsets != inner_shard.metadata.shard_offsets: + st_meta.shards_metadata.append( + ShardMetadata( + shard_offsets=_element_wise_add( + outer_shard.metadata.shard_offsets, + inner_md.shard_offsets, + ), + shard_sizes=inner_md.shard_sizes, + placement=f"rank:{other_rank}/{device_info}", + ) + ) + + # Finally add this shard + st_meta.shards_metadata.append(local_shards[0].metadata) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=st_meta, + ) + set_element(new_state_dict, path, st) + + traverse_state_dict(state_dict, rewrite_dict) + return new_state_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py new file mode 100644 index 0000000000000000000000000000000000000000..155a87b9dec5bcd1f532d17ee2b8ef56454e37ab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_state_dict_stager.py @@ -0,0 +1,467 @@ +# mypy: allow-untyped-defs +import types +import warnings +import weakref +from copyreg import dispatch_table +from typing import Any + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils +from torch.storage import UntypedStorage +from torch.utils.weak import WeakIdKeyDictionary + + +class StateDictStager: + """ + A class for optimizing storage objects during staging for async checkpointing. + + StateDictStager stages the state_dict to CPU DRAM while applying optimizations + like memory sharing and pinning to improve performance. It caches storage objects + to avoid redundant copies and can be configured to automatically share memory + (for multi-process usage) and pin memory (for faster CPU-GPU transfers). + + Attributes: + pin_memory (bool): Whether to pin CPU memory for faster CPU-GPU transfers + share_memory (bool): Whether to share memory across processes + pin_memory_min_bytes (int): Minimum tensor size in bytes to pin memory (default: 5) + _cached_storage_mapping (WeakIdKeyDictionary): Maps storage objects to optimized CPU storages using weak references + """ + + def __init__( + self, + pin_memory: bool = False, + share_memory: bool = False, + pin_memory_min_bytes: int = 5, + ): + if pin_memory and not torch.cuda.is_available(): + warnings.warn( + "Ignoring pin_memory flag for checkpoint staging as pinning memory" + "requires CUDA, but CUDA is not available. ", + stacklevel=2, + ) + self.pin_memory = False + else: + self.pin_memory = pin_memory + self.share_memory = share_memory + # Mapping from original storage objects to CPU storages using weak references + self._cached_storage_mapping = WeakIdKeyDictionary() + self.pin_memory_min_bytes = pin_memory_min_bytes + + def _deepcopy_atomic(x, _): + return x + + def _deepcopy_list(x, memo, non_blocking=False): + y: list = [] + memo[id(x)] = y + append = y.append + for a in x: + append( + self.deepcopy_with_tensor_offload( + a, memo, non_blocking=non_blocking + ) + ) + return y + + def _deepcopy_tuple(x, memo, non_blocking=False): + y = [ + self.deepcopy_with_tensor_offload(a, memo, non_blocking=non_blocking) + for a in x + ] + # We're not going to put the tuple in the memo, but it's still important we + # check for it, in case the tuple contains recursive mutable structures. + try: + return memo[id(x)] + except KeyError: + pass + + # Check if any elements changed during deepcopy + for k, j in zip(x, y): + if k is not j: + # At least one element changed, create new tuple + return tuple(y) + + # No elements changed, return original tuple + return x + + def _deepcopy_dict(x, memo, non_blocking=False): + y: dict = {} + memo[id(x)] = y + for key, value in x.items(): + y[ + self.deepcopy_with_tensor_offload( + key, memo, non_blocking=non_blocking + ) + ] = self.deepcopy_with_tensor_offload( + value, memo, non_blocking=non_blocking + ) + return y + + def _deepcopy_method(x, memo, non_blocking=False): # Copy instance methods + return type(x)( + x.__func__, + self.deepcopy_with_tensor_offload( + x.__self__, memo, non_blocking=non_blocking + ), + ) + + d: dict[Any, Any] = {} + self._deepcopy_dispatch = d + d[type(None)] = _deepcopy_atomic + d[int] = _deepcopy_atomic + d[float] = _deepcopy_atomic + d[bool] = _deepcopy_atomic + d[complex] = _deepcopy_atomic + d[bytes] = _deepcopy_atomic + d[str] = _deepcopy_atomic + d[types.CodeType] = _deepcopy_atomic + d[type] = _deepcopy_atomic + d[range] = _deepcopy_atomic + d[types.BuiltinFunctionType] = _deepcopy_atomic + d[types.FunctionType] = _deepcopy_atomic + d[weakref.ref] = _deepcopy_atomic + d[property] = _deepcopy_atomic + d[types.MethodType] = _deepcopy_method + d[dict] = _deepcopy_dict + d[tuple] = _deepcopy_tuple + d[list] = _deepcopy_list + + def _stage_untyped_storage( + self, + storage: UntypedStorage, + non_blocking: bool = False, + ): + """ + Called from the hooked storage_deepcopy function in torch.Tensor.__deepcopy__. + + This method handles the storage optimization logic for the StagingStateDict class. + It checks if the storage has already been cached, and if so, reuses it. + Otherwise, it creates a new CPU storage and applies memory optimizations. + + Args: + storage: The storage to optimize + + Returns: + The optimized storage + """ + # Check if we've already cached this storage + if storage in self._cached_storage_mapping: + cached_storage = self._cached_storage_mapping[storage] + assert cached_storage.size() == storage.size(), ( + "For async checkpointing, We cache storages in DRAM and reuse them." + "Cached storage size does not match original storage size." + "This should never happen as we track the original storage weakref " + "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing." + ) + # Reuse cached storage but update with new data + cached_storage.copy_(storage, non_blocking=non_blocking) + return cached_storage + + # Create new CPU storage + if self.share_memory: + new_storage = type(storage)._new_shared(storage.size(), device="cpu") + else: + new_storage = type(storage)(storage.size(), device="cpu") + + # Skip pinning for tensors below the minimum size threshold + # Small tensors (e.g., optimizer step counters, scalars) have negligible + # transfer time improvement from pinning, but pinning overhead is significant + if self.pin_memory and new_storage.nbytes() >= self.pin_memory_min_bytes: + pin_memory_utils.pin_memory(new_storage.data_ptr(), new_storage.nbytes()) + # Set up a weak reference to unpin when cpu storage is garbage collected + f = weakref.finalize( + new_storage, pin_memory_utils.unpin_memory, new_storage.data_ptr() + ) + # This makes sure that the finalizer is not called after + # cuda context is destroyed. + f.atexit = False + + new_storage.copy_(storage, non_blocking=non_blocking) + + # Cache the storage - WeakIdKeyDictionary will automatically clean up when storage is garbage collected + self._cached_storage_mapping[storage] = new_storage + return new_storage + + @torch.no_grad() + def stage( + self, + state_dict: Any, + non_blocking: bool = False, + ) -> Any: + return self.deepcopy_with_tensor_offload(state_dict, None, [], non_blocking) + + def _offload_tensor(self, x, memo, non_blocking=False): + """ + Deep copy a PyTorch tensor with optimized storage handling. + + This method creates a CPU copy of a tensor while applying memory optimizations + like sharing and pinning based on the StateDictStager configuration. + + Args: + x: The tensor to copy + memo: Memo dictionary for tracking already copied objects + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A CPU copy of the tensor with optimized storage + """ + # if data_ptr is not 0, we allocate a new storage below. so we can skip + # memory allocation by using [] for size. + y = x.new_empty([] if x.data_ptr() != 0 else x.size(), device="cpu") + + # Store in memo dict early to handle recursive references + d = id(x) + memo[d] = y + + if type(x) is torch.Tensor or x.data_ptr() != 0: + # Get the untyped storage + untyped_storage = x.untyped_storage() + storage_id = id(untyped_storage) + + # Check if this storage has already been staged in this deepcopy operation + # This handles the case where different tensors share the same storage + # (e.g., FSDP state_dict where norm.weight and norm_weight reference same storage) + # PyTorch caches untyped_storage() calls, so same storage -> same id + if storage_id in memo: + copied_storage = memo[storage_id] + else: + # Storage not seen before in this operation, stage it + copied_storage = self._stage_untyped_storage( + untyped_storage, non_blocking=non_blocking + ) + # Add to memo to avoid re-staging if we see this storage again + memo[storage_id] = copied_storage + + # Set the tensor data using the staged storage + y.set_(copied_storage, x.storage_offset(), x.size(), x.stride()) + + # Copy any attributes the tensor might have + if hasattr(x, "__dict__"): + for attr_name, attr_value in x.__dict__.items(): + setattr( + y, + attr_name, + self.deepcopy_with_tensor_offload( + attr_value, memo, non_blocking=non_blocking + ), + ) + + if hasattr(x, "__slots__"): + for slot in x.__slots__: + if hasattr(x, slot): + setattr( + y, + slot, + self.deepcopy_with_tensor_offload( + getattr(x, slot), memo, non_blocking=non_blocking + ), + ) + + return y + + def close(self): + """ + Clean up all cached storages and release associated resources. + + This method clears the internal storage cache, allowing garbage collection + of cached CPU storages. Any pinned memory associated with cached storages + will be automatically unpinned through weak reference finalizers. + """ + self._cached_storage_mapping.clear() + + @torch.no_grad() + def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006 + """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors. + + This implementation extends the standard deepcopy functionality to handle PyTorch tensors + and their storages in a way that optimizes memory usage and performance, similar to the + stage method. It applies memory sharing and pinning optimizations based on the StateDictStager + configuration. + + Args: + x: The object to deep copy + memo: Memo dictionary for tracking already copied objects + _nil: Sentinel value for memo dictionary + non_blocking: Whether to perform non-blocking copies where possible + + Returns: + A deep copy of the input object with optimized tensor storage handling + """ + if memo is None: + memo = {} + + d = id(x) + y = memo.get(d, _nil) + if y is not _nil: + return y + + cls = type(x) + + # tensors and subclasses of tensors are handled separately + if isinstance(x, torch.Tensor): + y = self._offload_tensor(x, memo, non_blocking=non_blocking) + else: + # Use the dispatch table for standard types + copier = self._deepcopy_dispatch.get(cls) + if copier is not None: + # Check if this is an atomic copier (only accepts x and memo) + if copier.__name__ == "_deepcopy_atomic": + y = copier(x, memo) + else: + y = copier(x, memo, non_blocking=non_blocking) + else: + if issubclass(cls, type): + # type copier is also atomic + y = self._deepcopy_dispatch[type](x, memo) + else: + copier = getattr(x, "__deepcopy__", None) + if copier is not None: + y = copier(memo) + else: + reductor = dispatch_table.get(cls) + if reductor: + rv = reductor(x) + else: + reductor = getattr(x, "__reduce_ex__", None) + if reductor is not None: + rv = reductor(4) + else: + reductor = getattr(x, "__reduce__", None) + if reductor: + rv = reductor() + else: + raise RuntimeError( + f"un(deep)copyable object of type {cls}" + ) + if isinstance(rv, str): + y = x + else: + # Unpack rv tuple elements (up to 5 from pickle protocol) + # and explicitly pass non_blocking as keyword arg + if len(rv) == 2: + func, args = rv + y = self._reconstruct( + x, memo, func, args, non_blocking=non_blocking + ) + elif len(rv) == 3: + func, args, state = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + non_blocking=non_blocking, + ) + elif len(rv) == 4: + func, args, state, listiter = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + listiter, + non_blocking=non_blocking, + ) + elif len(rv) == 5: + func, args, state, listiter, dictiter = rv + y = self._reconstruct( + x, + memo, + func, + args, + state, + listiter, + dictiter, + non_blocking=non_blocking, + ) + else: + raise RuntimeError( + f"Unexpected pickle protocol return value length: {len(rv)}" + ) + + # If is its own copy, don't memoize. + if y is not x: + memo[d] = y + self._keep_alive(x, memo) # Make sure x lives at least as long as d + return y + + def _keep_alive(self, x, memo): + """Keeps a reference to the object x in the memo. + + Because we remember objects by their id, we have + to assure that possibly temporary objects are kept + alive by referencing them. + We store a reference at the id of the memo, which should + normally not be used unless someone tries to deepcopy + the memo itself... + """ + try: + memo[id(memo)].append(x) + except KeyError: + # aha, this is the first one :-) + memo[id(memo)] = [x] + + def _reconstruct( + self, + x, + memo, + func, + args, + state=None, + listiter=None, + dictiter=None, + non_blocking=False, + ): + deep = memo is not None + if deep and args: + args = tuple( + self.deepcopy_with_tensor_offload(arg, memo, non_blocking=non_blocking) + for arg in args + ) + y = func(*args) + if deep: + memo[id(x)] = y + + if state is not None: + if deep: + state = self.deepcopy_with_tensor_offload( + state, memo, non_blocking=non_blocking + ) + if hasattr(y, "__setstate__"): + y.__setstate__(state) + else: + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + else: + slotstate = None + if state is not None: + y.__dict__.update(state) + if slotstate is not None: + for key, value in slotstate.items(): + setattr(y, key, value) + + if listiter is not None: + if deep: + for item in listiter: + item = self.deepcopy_with_tensor_offload( + item, memo, non_blocking=non_blocking + ) + y.append(item) + else: + for item in listiter: + y.append(item) + if dictiter is not None: + if deep: + for key, value in dictiter: + key = self.deepcopy_with_tensor_offload( + key, memo, non_blocking=non_blocking + ) + value = self.deepcopy_with_tensor_offload( + value, memo, non_blocking=non_blocking + ) + y[key] = value + else: + for key, value in dictiter: + y[key] = value + return y diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73acc628342a058f659042b2d41c8245c86c2c42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_storage_utils.py @@ -0,0 +1,49 @@ +import os +from typing import Union + +from .filesystem import FileSystemReader, FileSystemWriter +from .storage import StorageReader, StorageWriter + + +def _storage_setup( + storage: Union[StorageReader, StorageWriter, None], + checkpoint_id: Union[str, os.PathLike, None], + reader: bool = False, +) -> Union[None, StorageReader, StorageWriter]: + if storage: + if checkpoint_id is not None: + storage.reset(checkpoint_id) + return storage + + if not checkpoint_id: + raise RuntimeError( + "`checkpoint_id` must be specified if " + "storage_reader/storage_writer is None." + ) + + targets: list[type[Union[StorageReader, StorageWriter]]] = [] + if reader: + targets = [ + FileSystemReader, + ] + else: + targets = [ + FileSystemWriter, + ] + try: + from ._fsspec_filesystem import FsspecReader, FsspecWriter + + targets.append(FsspecReader if reader else FsspecWriter) + except Exception: + pass + + for target in targets: + if target.validate_checkpoint_id(checkpoint_id): + storage = target(checkpoint_id) # type: ignore[call-arg] + storage.reset(checkpoint_id) + return storage + + raise RuntimeError( + "Cannot detect which StorageReader or StorageWriter to use. " + "Please specify the storage_reader/storage_writer." + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..48eb67b4f7621b1aa3a4d6b2d7c56c5503337eb7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_traverse.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable, Collection, Mapping, MutableMapping +from typing import cast, Optional, TypeVar, Union + +import torch +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.tensor import DTensor + + +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] +T = TypeVar("T") + +STATE_DICT_ITEM = object +CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] + +__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] + + +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: + return isinstance(value, torch.Tensor) + + +# TODO: update docstring for traverse.py +def traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. + """ + + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + return False + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif _is_terminal(value): + visitor(path, value) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def set_element( + root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM +) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else []) + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + # pyrefly: ignore [bad-argument-type] + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) is int: + extend_list(cast(list[STATE_DICT_ITEM], cur_container), key) + + cur_container[key] = value + + +def get_element( + root_dict: STATE_DICT_TYPE, + path: OBJ_PATH, + default_value: Optional[T] = None, +) -> Optional[T]: + """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.""" + cur_value = cast(CONTAINER_TYPE, root_dict) + for part in path: + if type(part) is int: + if not isinstance(cur_value, list) or len(cur_value) < part: + return default_value + elif not isinstance(cur_value, Mapping) or part not in cur_value: + return default_value + + # pyrefly: ignore [index-error] + cur_value = cast(CONTAINER_TYPE, cur_value[part]) + return cast(Optional[T], cur_value) + + +def _print_nested( + value: STATE_DICT_ITEM, + prefix: str = "", + print_fun: Callable[[str], None] = print, +) -> None: + if type(value) is ShardedTensor: + print_fun(f"{prefix} ShardedTensor size: {value.size()}") + for shard in value.local_shards(): + _print_nested( + shard.tensor, + f"{shard.metadata.shard_offsets} ", + print_fun=print_fun, + ) + elif type(value) is (DTensor): + print_fun(f"{prefix} DistributedTensor size: {value.size()}") + # TODO: add local offset for _local_tensor in print_nested. + _print_nested( + value._local_tensor, + print_fun=print_fun, + ) + elif isinstance(value, torch.Tensor): + print_fun(f"{prefix} Tensor size: {value.size()}") + else: + print_fun(f"{prefix} Type: {type(value)}") + + +def print_tensor( + path: OBJ_PATH, + value: STATE_DICT_ITEM, + print_fun: Callable[[str], None] = print, +) -> None: + """ + Use this callback with traverse_state_dict to print its content. + + By default the content is printed using the builtin ``print`` but this can + be change by passing a different ``print_fun` callable. + """ + _print_nested(value, prefix=str(path), print_fun=print_fun) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..b3065bdfd6a2c141a959ef0ffe30aeafdc2dc54f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa4854db2358ae4361403d37d59563ab8963fbd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/api.py @@ -0,0 +1,42 @@ +import traceback as tb +from typing import Any + + +WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary] + +__all__ = ["CheckpointException"] + + +def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: + return (exc, tb.extract_tb(exc.__traceback__)) + + +def _is_wrapped_exception(obj: Any) -> bool: + if not isinstance(obj, tuple): + return False + if len(obj) != 2: + return False + return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) + + +class CheckpointException(BaseException): + """Exception raised if failure was detected as part of a checkpoint load or save.""" + + def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]): + super().__init__(msg, failures) + self._failures = failures + + @property + def failures(self) -> dict[int, WRAPPED_EXCEPTION]: + """Return a dictionary mapping node ranks to their associated exceptions in case of failure.""" + return self._failures + + def __str__(self) -> str: + str = f"CheckpointException ranks:{self._failures.keys()}\n" + for rank, exc_pair in self._failures.items(): + exc, trace = exc_pair + str += f"Traceback (most recent call last): (RANK {rank})\n" + if trace is not None: + str += "".join(tb.format_list(trace)) + str += "".join(tb.format_exception_only(type(exc), value=exc)) + return str diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py new file mode 100644 index 0000000000000000000000000000000000000000..716cb90a996534e4388a42545935ebee894eeb1a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py @@ -0,0 +1,702 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +import io +import logging +import math +import sys +from bisect import bisect_right, insort +from collections import ChainMap +from typing import Any, cast, Optional, Union + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans +from torch.distributed.checkpoint._nested_dict import ( + FLATTEN_MAPPING, + flatten_state_dict, +) +from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors +from torch.distributed.checkpoint._traverse import set_element +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + StorageMeta, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.planner_helpers import ( + _compare_save_plans, + _contains_usable_plan, + _create_default_metadata_only_plan, + _create_read_items, + _create_write_items, + _init_state_dict, + _merge_delta_local_plans, +) +from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.tensor import DTensor + +from . import _version + + +logger: logging.Logger = logging.getLogger(__name__) + + +__all__ = [ + "DefaultSavePlanner", + "DefaultLoadPlanner", + "create_default_local_load_plan", + "create_default_global_load_plan", + "create_default_local_save_plan", + "create_default_global_save_plan", +] + + +# TODO: Update docstrings for default_planner.py +class DefaultSavePlanner(SavePlanner): + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, + enable_plan_caching: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.mappings = {} + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank + if dedup_replicated_tensors is not None: + logger.warning( + "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " + "deprecated, and no longer has any effect. Please remove this argument " + "from your call." + ) + self._cached_plans_key: str = self.__class__.__name__ + self._enable_plan_caching = enable_plan_caching + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + self.state_dict = state_dict + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> SavePlan: + plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + + if self._enable_plan_caching: + # If plans are equal, we can skip sending the plan to the coordinator. + if ( + self._cached_plans_key in SavePlanner._cached_save_plan + and _compare_save_plans( + plan, SavePlanner._cached_save_plan[self._cached_plans_key] + ) + ): + logger.info( + "No change in the local plan. Skipping sending the plan to the coordinator" + ) + return SavePlan([], usable=False) + else: + SavePlanner._cached_save_plan[self._cached_plans_key] = plan + + return self.plan + + def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]: + return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) + + def _create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + deduped_plans = self._dedup_save_plans(all_plans) + + global_plan, metadata = create_default_global_save_plan(deduped_plans) + + if self.flatten_state_dict: + # | does not work for Python 3.8 or older version. + # merged_mappings = reduce( + # lambda x, y: x | y, (p.planner_data for p in global_plan) + # ) + planner_data_dict = [p.planner_data for p in global_plan] + merged_mappings = dict(ChainMap(*planner_data_dict)) + metadata = dataclasses.replace(metadata, planner_data=merged_mappings) + + if not _validate_global_plan(global_plan, metadata): + raise ValueError("Failed to validate global plan") + + return global_plan, metadata + + def _create_global_plan_with_caching( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], list[SavePlan], Metadata]: + """ + Create global plan with caching. + Returns a tuple of global_plan_delta, global_plan, metadata. + """ + global_plan_delta: list[SavePlan] = [] + + if self._cached_plans_key not in SavePlanner._cached_all_plans: + # Case 1: If the plans are not cached, the cache will be hydrated with the + # all_plans, global_plans (Deduped), and metadata. + + # Cache the original all_plans + SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans + global_plan, metadata = self._create_global_plan(all_plans) + # Cache the deduped and validated global_plan + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + # Cache the metadata + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + # If plans are not cached, global_plan delta will be the same as global plan. + return global_plan, global_plan, metadata + + # Case 2: Plans are cached + if not _contains_usable_plan(all_plans): + # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans). + # Global plan delta will be empty plans to avoid the collective overhead. + # We can reuse the deduped global plan and metadata from the cache directly. + global_plan_delta = [SavePlan([], usable=False)] * len(all_plans) + global_plan = SavePlanner._cached_global_plan[self._cached_plans_key] + metadata = SavePlanner._cached_metadata[self._cached_plans_key] + else: + # Case 2.2: Plans are cached but the local plans have changed. + # We will merge the changed local plans with the cached local plans. + # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached. + # Global plan delta will be created by comparing the new global plan with the cached global plan. + # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead. + merged_plans = _merge_delta_local_plans( + SavePlanner._cached_all_plans[self._cached_plans_key], all_plans + ) + # Cache the updated local plans + SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans + global_plan, metadata = self._create_global_plan(merged_plans) + + if self._cached_plans_key in self._cached_global_plan: + for cached_plan, new_plan in zip( + SavePlanner._cached_global_plan[self._cached_plans_key], global_plan + ): + if _compare_save_plans(cached_plan, new_plan): + global_plan_delta.append(SavePlan([], usable=False)) + else: + global_plan_delta.append(new_plan) + + # Cache the new global plan and the metadata + SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan + SavePlanner._cached_metadata[self._cached_plans_key] = metadata + + return global_plan_delta, global_plan, metadata + + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + global_plan_delta: list[SavePlan] = [] + if self._enable_plan_caching: + # If the plans are cached, we only need to send the global plan delta to be scattered + # across ranks. Ranks will use the cached final plans instead. + ( + global_plan_delta, + global_plan, + metadata, + ) = self._create_global_plan_with_caching(all_plans) + else: + global_plan, metadata = self._create_global_plan(all_plans) + # If the caching is not enabled, global delta plan will always be same as the new global plan. + global_plan_delta = global_plan + + self.global_plan = global_plan + self.metadata = metadata + + return global_plan_delta, self.metadata + + def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if not new_plan.usable: + finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key] + else: + finished_plan = new_plan + SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan + return finished_plan + + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + finished_plan: SavePlan = new_plan + + if self._enable_plan_caching: + finished_plan = self._finish_plan_with_caching(new_plan) + + self.plan = finished_plan + return self.plan + + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + object = self.lookup_object(write_item.index) + return self.transform_object(write_item, object) + + def lookup_object(self, index: MetadataIndex) -> Any: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_object(self, write_item: WriteItem, object: Any): + """Extension from the planner interface to make it easy to extend the default planner.""" + if write_item.type == WriteItemType.BYTE_IO: + bytes = io.BytesIO() + torch.save(object, bytes) + object = bytes + return object + + +class DefaultLoadPlanner(LoadPlanner): + """ + DefaultLoadPlanner that adds multiple features on top of LoadPlanner. + + In particular it adds the following: + + flatten_state_dict: Handle state_dict with nested dicts + flatten_sharded_tensors: For FSDP in 2D parallel mode + allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint. + """ + + original_state_dict: STATE_DICT_TYPE + mappings: FLATTEN_MAPPING + + def __init__( + self, + flatten_state_dict: bool = True, + flatten_sharded_tensors: bool = True, + allow_partial_load: bool = False, + ) -> None: + self.flatten_state_dict = flatten_state_dict + self.flatten_sharded_tensors = flatten_sharded_tensors + self.original_state_dict = {} + self.mappings = {} + self.allow_partial_load = allow_partial_load + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + _init_state_dict(state_dict) + self.original_state_dict = state_dict + + if self.flatten_sharded_tensors: + state_dict = _flatten_sharded_tensors(state_dict) + + if self.flatten_state_dict: + state_dict, self.mappings = flatten_state_dict(state_dict) + + self.state_dict = state_dict + self.metadata = metadata + self.is_coordinator = is_coordinator + + def create_local_plan(self) -> LoadPlan: + if self.metadata is None: + raise AssertionError("self.metadata is not None") + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. + _version._derived_version = None + + return create_default_local_load_plan( + self.state_dict, self.metadata, not self.allow_partial_load + ) + + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + return create_default_global_load_plan(global_plan) + + def finish_plan(self, new_plan: LoadPlan) -> LoadPlan: + return new_plan + + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + if self.flatten_state_dict: + set_element( + self.original_state_dict, + self.mappings[read_item.dest_index.fqn], + torch.load(value, weights_only=False), + ) + else: + self.state_dict[read_item.dest_index.fqn] = torch.load( + value, weights_only=False + ) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + pass + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """Extension from the planner interface to make it easy to extend the default planner.""" + return find_state_dict_object(self.state_dict, index) + + def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): + """Extension from the planner interface to make it easy to extend the default planner.""" + return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths) + + +class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata. + Useful for loading in state_dict without first initializing a model, such as + when converting a DCP checkpoint into a Torch save file. + + . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner + + .. warning:: + Because the entire state dict is initialized, It's recommended to only utilize + this LoadPlanner on a single rank or process to avoid OOM. + + """ + + def __init__(self, keys=None, *args, **kwargs): + self.keys = keys + super().__init__(*args, **kwargs) + + def _should_include_key(self, key: str, metadata: Metadata) -> bool: + if self.keys is None: + return True + + if key in self.keys: + return True + + unflattened_keys: list[str] = [] + planner_data = metadata.planner_data.get(key) + for unflattened_key in planner_data: + if unflattened_keys: + unflattened_keys.append( + ".".join([unflattened_keys[-1], str(unflattened_key)]) + ) + + else: + unflattened_keys.append(unflattened_key) + + if any(unflattened_key in self.keys for unflattened_key in unflattened_keys): + return True + + return False + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + if state_dict: + raise AssertionError("not state_dict") + if metadata is None: + raise AssertionError("metadata is not None") + + # rebuild the state dict from the metadata + for k, v in metadata.state_dict_metadata.items(): + if not self._should_include_key(k, metadata): + continue + + if isinstance(v, TensorStorageMetadata): + v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment] + if metadata.planner_data is not None and k in metadata.planner_data: + set_element(state_dict, metadata.planner_data[k], v) + else: + state_dict[k] = v + + super().set_up_planner(state_dict, metadata, is_coordinator) + + +def create_default_local_load_plan( + state_dict: dict[str, Any], metadata: Metadata, strict: bool = True +) -> LoadPlan: + requests = [] + """ + Create the ``LoadPlan`` used by DefaultLoadPlanner. + + It produces one read item per value in ``state_dict`` using the metadata in ``metadata``. + + The default behavior is to match key exactly between state_dict and metadata. + It handles resharding by issuing multiple read requests against storage in order to match + load requirements. + """ + + for fqn, obj in state_dict.items(): + # ignore state_dict keys which do not exist in `state_dict` if strict=False + if fqn not in metadata.state_dict_metadata: + if strict: + raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.") + else: + continue + + md = metadata.state_dict_metadata[fqn] + if ( + isinstance(md, TensorStorageMetadata) + and getattr(obj, "size", None) is not None + and md.size != obj.size() + ): + raise ValueError( + f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}", + ) + # Since DTensor supports submesh, adding extra check to ensure _create_read_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_read_items(fqn, md, obj) + else: + requests += _create_read_items(fqn, md, obj) + + return LoadPlan(requests) + + +def create_default_global_load_plan( + all_plans: list[LoadPlan], +) -> list[LoadPlan]: + """ + Create global load plan used by DefaultLoadPlanner. + + The default load behavior involved no global coordination and this function + currently doesn't change the local plans. + """ + return all_plans + + +def create_default_local_save_plan( + state_dict: dict[str, Any], is_coordinator: bool +) -> SavePlan: + """ + Create the ``SavePlan`` used by DefaultSavePlanner. + + On non-coordinator ranks, this function ignores tensors and non-tensor objects, + only producing writes for ShardedTensor objects. + + On the coordinator rank, produce writes for all values. + """ + requests = [] + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_write_items(fqn, obj) + else: + # For the plain tensor and non-tensor values, add the request for all + # the ranks. Coordinator will decides whether to deduplicate the + # values based on the keys. + requests += _create_write_items(fqn, obj) + + return SavePlan(requests) + + +def create_default_global_save_plan( + all_plans: list[SavePlan], + rewrite_index_hints: bool = True, +) -> tuple[list[SavePlan], Metadata]: + """ + Create the global plan and metadata used by DefaultSavePlanner. + + Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans. + + The only global planning change is to update index hints in all ``MetadataIndex`` objects if + ``rewrite_index_hints`` is True. + """ + md: dict[str, STORAGE_TYPES] = {} + new_plans = [] + for plan in all_plans: + new_items = [] + for item in plan.items: + if item.type != WriteItemType.SHARD: + if item.index.fqn in md: + raise AssertionError("item.index.fqn not in md") + + if item.type == WriteItemType.BYTE_IO: + md[item.index.fqn] = BytesStorageMetadata() + new_items.append(item) + else: + if item.tensor_data is None: + raise AssertionError("item.tensor_data is not None") + tensor_md = cast( + TensorStorageMetadata, + md.setdefault( + item.index.fqn, + TensorStorageMetadata( + properties=item.tensor_data.properties, + size=item.tensor_data.size, + chunks=[], + ), + ), + ) + new_item = item + if rewrite_index_hints: + new_index = dataclasses.replace( + item.index, index=len(tensor_md.chunks) + ) + new_item = dataclasses.replace(item, index=new_index) + new_items.append(new_item) + + if item.tensor_data.chunk is None: + raise AssertionError(f""" + Cannot create MD for tensor without bounds. + FQN: {item.index.fqn} + """) + tensor_md.chunks.append(item.tensor_data.chunk) + new_plans.append(dataclasses.replace(plan, items=new_items)) + return (new_plans, Metadata(md)) + + +def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata: + """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.""" + plan = _create_default_metadata_only_plan(state_dict) + _, md = create_default_global_save_plan([plan]) + return md + + +def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool: + """Check if two boxes overlap. Tuples are (offset, lengths).""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(box0.offsets) + for i in range(ndims): + if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]: + return False + if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]: + return False + + return True + + +def _check_box_bounds( + outer_box_size: torch.Size, inner_box: ChunkStorageMetadata +) -> bool: + for i in range(len(outer_box_size)): + if inner_box.offsets[i] < 0: + return False + if inner_box.sizes[i] < 0: + return False + if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]: + return False + + return True + + +def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool: + all_good = True + for key, value in metadata.state_dict_metadata.items(): + if isinstance(value, BytesStorageMetadata): + continue + if len(value.size) == 0: + continue + chunks = value.chunks + chunks_volume = 0 + for chunk in chunks: + # Compute the volume + if not _check_box_bounds(value.size, chunk): + logger.warning( + """ + key:%s has out of bounds chunk: + tensor-size:%s chunk: %s + """, + key, + value.size, + chunk, + ) + all_good = False + chunks_volume += math.prod(chunk.sizes) + + if len(chunks) > 1: + dims = len(value.size) + sweep_dim = max(range(dims), default=0, key=lambda d: value.size[d]) + sorted_indices = sorted( + range(len(chunks)), + key=lambda idx: ( + chunks[idx].offsets[sweep_dim], + *(chunks[idx].offsets[d] for d in range(dims)), + ), + ) + active: list[tuple[int, int]] = [] + for idx in sorted_indices: + current = chunks[idx] + start = current.offsets[sweep_dim] + end = start + current.sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = chunks[other_idx] + if _check_box_overlap(current, other): + logger.warning( + "key:%s has overlapping chunks: %s %s", + key, + current, + other, + ) + all_good = False + + insort(active, (end, idx)) + + # Check whether combined chunk cover the whole tensor + tensor_volume = math.prod(value.size) + if len(global_plan) > 1 and chunks_volume != tensor_volume: + logger.warning( + """ + key:%s invalid fill tensor-volume: + %s chunks-volume: %s + """, + key, + tensor_volume, + chunks_volume, + ) + all_good = False + + return all_good diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..b21cac12ff90522f075b7b32029eae01e7a92169 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/filesystem.py @@ -0,0 +1,1035 @@ +# mypy: allow-untyped-defs +import collections +import dataclasses +import io +import json +import operator +import os +import pickle +import queue +import threading +import uuid +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from io import UnsupportedOperation +from pathlib import Path +from typing import Any, cast, Final, IO, Optional, Union + +# introduced as collections.abc.Buffer in Python 3.12 +from typing_extensions import Buffer + +import torch +from torch import Tensor +from torch._utils import _get_available_device_type, _get_device_module +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._extension import ( + ExtensionRegistry, + StreamTransformExtension, +) +from torch.distributed.checkpoint._hf_utils import ( + CUSTOM_METADATA_KEY, + DCP_VERSION_KEY, + FORMAT_KEY, + FORMAT_VALUE, + HF_DCP_VERSION, +) +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.staging import BlockingAsyncStager +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) +from torch.distributed.checkpoint.utils import _create_file_view +from torch.futures import Future + + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", + "FileSystem", + "FileSystemBase", + "SerializationFormat", +] + +_metadata_fn: str = ".metadata" + +CURRENT_DCP_VERSION: Final[str] = "1.0.0" + + +@dataclass +class _StorageInfo: + """This is the per entry storage info.""" + + relative_path: str + offset: int + length: int + transform_descriptors: Optional[Sequence[str]] = None + + def __getstate__(self): + return {k: v for k, v in self.__dict__.items() if v is not None} + + +@dataclass +class _StoragePrefix: + prefix: str + + +class SerializationFormat(Enum): + TORCH_SAVE = "torch_save" + SAFETENSORS = "safetensors" + + +DEFAULT_SUFFIX = ".distcp" + + +def _generate_uuid() -> str: + return str(uuid.uuid4()) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, size: int, obj: object) -> None: + pass + + @abstractmethod + def start_loading(self) -> None: + pass + + @abstractmethod + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + pass + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun: Callable) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + + def add(self, size: int, obj: object) -> None: + self.items.append((size, obj)) + + def start_loading(self) -> None: + pass + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + for _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + tensor = tensor.cpu() + if tensor.storage().size() != tensor.numel(): + tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun: Callable, + stream: Optional[torch.Stream] = None, + inflight_threshhold: int = 1_000_000, + ) -> None: + self.resolve_fun = resolve_fun + self.items: list[tuple[int, object]] = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = ( + stream.device_type if stream else _get_available_device_type() + ) + self.device_module = _get_device_module(self.device_type) + self.stream = cast( + torch.cuda.Stream, stream or self.device_module.current_stream() + ) + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self) -> bool: + return self.idx >= len(self.items) + + def _drain(self) -> list[tuple[torch.Tensor, object]]: + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self) -> None: + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if tensor.device.type == self.device_type: + tensor = tensor.to(device="cpu", non_blocking=True) + elif tensor.device == torch.device("cpu"): + if ( + tensor.untyped_storage().size() + != tensor.numel() * tensor.itemsize + ): + # this forces the tensor to be both contiguous and with minimal storage + tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self) -> Iterable[tuple[torch.Tensor, object]]: + if not self._done: + raise AssertionError("_finish called before all items were processed") + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, size: int, obj: object) -> None: + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((size, obj)) + + def start_loading(self) -> None: + if self.started: + return + self.started = True + self.items.sort(key=operator.itemgetter(0)) + self._refill() + + def values(self) -> Iterator[tuple[torch.Tensor, object]]: + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +class _StorageWriterTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__( + self, extensions: Optional[Sequence[StreamTransformExtension]] = None + ) -> None: + """ + If the extensions arg is None, this means the implementation + should provide whatever defaults it chooses. An empty + sequence indicates no extensions should be used. At this + time, the default extensions sequence is empty. + """ + self.extensions = () if extensions is None else extensions + + def transform_save_stream( + self, write_item: WriteItem, raw_stream: io.IOBase + ) -> tuple[IO[bytes], list[str]]: + # In order to avoid leaking fds, transformers' close must + # cascade to wrapped streams, but since this function can + # append to the raw stream, we can't close the actual stream. + # So, we use this to put a wrapper around the raw stream's + # close() to make it a noop, and it gets closed once all files + # are appended. + + class NoCloseWriter(io.IOBase): + def __init__(self, raw: io.IOBase): + self.raw = raw + + def writeable(self) -> bool: + return True + + def write(self, b: Buffer) -> int: + return self.raw.write(b) + + def close(self): + self.flush() + self.raw.flush() + # but not close. + + transform_to = cast(IO[bytes], NoCloseWriter(raw_stream)) + + for ex in self.extensions: + transform_to = ex.transform_to(transform_to) + + return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)]) + + +def _item_size(item: WriteItem) -> int: + size = 1 + if item.tensor_data is None: + raise AssertionError("WriteItem tensor_data must not be None") + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: list[list[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + # TODO replace with headq + idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item( + transforms: _StorageWriterTransforms, + stream: io.IOBase, + data: Union[io.BytesIO, torch.Tensor], + write_item: WriteItem, + storage_key: str, + serialization_format: SerializationFormat, +) -> WriteResult: + offset = stream.tell() + + (transform_to, transform_descriptors) = transforms.transform_save_stream( + write_item, stream + ) + + if write_item.type == WriteItemType.BYTE_IO: + if not isinstance(data, io.BytesIO): + raise AssertionError("Data must be io.BytesIO for BYTE_IO write items") + transform_to.write(data.getbuffer()) + else: + if not isinstance(data, torch.Tensor): + raise AssertionError( + "Data must be torch.Tensor for non-BYTE_IO write items" + ) + if data.device != torch.device("cpu"): + raise AssertionError("Tensor must be on CPU device") + if serialization_format == SerializationFormat.TORCH_SAVE: + torch.save(data, transform_to) + + transform_to.close() + + if serialization_format == SerializationFormat.TORCH_SAVE or isinstance( + data, io.BytesIO + ): + length = stream.tell() - offset + else: + length = data.numel() * data.element_size() + + # For consistency with earlier versions, leave this field out of the + # metadata if there are no extensions. + info_transform_descriptors = ( + None if len(transform_descriptors) == 0 else transform_descriptors + ) + + return WriteResult( + index=write_item.index, + size_in_bytes=length, + storage_data=_StorageInfo( + storage_key, + offset, + length, + transform_descriptors=info_transform_descriptors, + ), + ) + + +def _write_files_from_queue( + create_stream: Callable, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: SavePlanner, + transforms: _StorageWriterTransforms, + inflight_threshhold: int, + use_fsync: bool, + thread_count: int, + serialization_format: SerializationFormat, +) -> None: + try: + while True: + file_name, storage_key, write_items = file_queue.get_nowait() + loader: _TensorLoader + + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + + # TODO: Using the OverlappingCpuLoader with multiple threads creates significant + # performance degradation, observed as being related to cuda stream syncs. We + # should try to fix this and use _OverlappingCpuLoader for all threaded cases + if ( + thread_count == 1 + and ( + torch.cuda.is_available() + or (custom_device_mod and custom_device_mod.is_available()) + ) + and inflight_threshhold > 0 + ): + loader = _OverlappingCpuLoader( + planner.resolve_data, + inflight_threshhold=inflight_threshhold, + ) + else: + loader = _SerialCpuLoader( + planner.resolve_data, + ) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + loader.add(_item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + write_results = [] + + with create_stream(file_name, "wb") as stream: + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append( + _write_item( + transforms, + stream, + data, + write_item, + storage_key, + serialization_format, + ) + ) + + tensor_dict = {} + metadata_dict = {} + for tensor, write_item in loader.values(): + if not tensor.is_cpu: + raise AssertionError("Tensor must be on CPU") + write_results.append( + _write_item( + transforms, + stream, + tensor, + write_item, # type: ignore[arg-type] + storage_key, + serialization_format, + ) + ) + tensor_dict[write_item.index.fqn] = tensor # type: ignore[attr-defined] + metadata_dict[write_item.index.fqn] = { # type: ignore[attr-defined] + "saved_offsets": write_item.tensor_data.chunk.offsets # type: ignore[attr-defined] + } + + if serialization_format == SerializationFormat.SAFETENSORS: + from safetensors.torch import save # type: ignore[import-not-found] + + stream.write( + save( + tensor_dict, + metadata={ + CUSTOM_METADATA_KEY: json.dumps(metadata_dict), + DCP_VERSION_KEY: str(HF_DCP_VERSION), + FORMAT_KEY: FORMAT_VALUE, + }, + ) + ) + + if use_fsync: + try: + os.fsync(stream.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + stream.close() + result_queue.put(write_results) + except queue.Empty: + pass + + +class FileSystemBase(ABC): + @contextmanager + @abstractmethod + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: ... + + @abstractmethod + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: ... + + @abstractmethod + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: ... + + @abstractmethod + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ... + + @abstractmethod + def mkdir(self, path: Union[str, os.PathLike]) -> None: ... + + @classmethod + @abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def exists(self, path: Union[str, os.PathLike]) -> bool: ... + + @abstractmethod + def rm_file(self, path: Union[str, os.PathLike]) -> None: ... + + +class FileSystem(FileSystemBase): + @contextmanager + def create_stream( + self, path: Union[str, os.PathLike], mode: str + ) -> Generator[io.IOBase, None, None]: + if not isinstance(path, Path): + path = Path(path) + with path.open(mode) as stream: + yield cast(io.IOBase, stream) + + def concat_path( + self, path: Union[str, os.PathLike], suffix: str + ) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path / suffix + + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: + if not isinstance(path, Path): + path = Path(path) + return path + + def rename( + self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] + ) -> None: + if not isinstance(path, Path): + path = Path(path) + + path.rename(cast(Path, new_path)) + + def mkdir(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + if isinstance(checkpoint_id, Path): + return True + + if "://" in str(checkpoint_id): + return False + + for p in Path(checkpoint_id).parents: + if p.exists() and os.access(str(p), os.W_OK): + return True + + return False + + def exists(self, path: Union[str, os.PathLike]) -> bool: + if not isinstance(path, Path): + path = Path(path) + return path.exists() + + def rm_file(self, path: Union[str, os.PathLike]) -> None: + if not isinstance(path, Path): + path = Path(path) + path.unlink() + + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + if not isinstance(path, Path): + path = Path(path) + return [str(p) for p in path.iterdir()] + + +class _FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.thread_count = thread_count + self.per_thread_copy_ahead = per_thread_copy_ahead + self.save_id = _generate_uuid() + self.overwrite = overwrite + self.transforms = _StorageWriterTransforms(_extensions) + self.serialization_format = serialization_format + self.rank: Optional[int] = None + self.use_collectives: bool = True + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.save_id = _generate_uuid() + + def set_up_storage_writer( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + self.rank = kwargs.get("rank") + self.use_collectives = kwargs.get("use_collectives", True) + + def _metadata_exists(self) -> bool: + if self.use_collectives: + # A global checkpoint metadata file + metadata_path = self._get_metadata_path(rank=None) + else: + # A rank 0 specific metadata file if every rank has written its own metadata + # Just looking for lowest rank metadata file is sufficient + metadata_path = self._get_metadata_path(rank=0) + + return self.fs.exists(metadata_path) + + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + self.fs.mkdir(self.path) + if self._metadata_exists(): + if self.overwrite: + warnings.warn( + f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}." + " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" + " maintain this functionality or False to raise when an existing checkpoint is found.", + stacklevel=2, + ) + else: + raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") + + if self.rank is not None and not self.use_collectives: + plan = dataclasses.replace( + plan, storage_data=_StoragePrefix(f"__{self.rank}_") + ) + + return plan + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) + if plan.storage_data is None + else plan + for i, plan in enumerate(plans) + ] + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + file_queue: queue.Queue = queue.Queue() + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.thread_count, plan.items): + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, bucket)) + else: + for item in plan.items: + file_name = gen_file() + path = self.fs.concat_path(self.path, file_name) + file_queue.put((path, file_name, [item])) + + return self._write_data(planner, file_queue) + + def _write_data( + self, + planner: SavePlanner, + file_queue: queue.Queue, + ) -> Future[list[WriteResult]]: + result_queue: queue.Queue = queue.Queue() + + threads = [] + for _ in range(1, self.thread_count): + t = threading.Thread( + target=_write_files_from_queue, + args=( + self.fs.create_stream, + file_queue, + result_queue, + planner, + self.transforms, + self.per_thread_copy_ahead, + self.sync_files, + self.thread_count, + self.serialization_format, + ), + ) + t.start() + threads.append(t) + + _write_files_from_queue( + create_stream=self.fs.create_stream, + file_queue=file_queue, + result_queue=result_queue, + planner=planner, + transforms=self.transforms, + inflight_threshhold=self.per_thread_copy_ahead, + use_fsync=self.sync_files, + thread_count=self.thread_count, + serialization_format=self.serialization_format, + ) + + for t in threads: + t.join() + + res = [] + try: + while True: + res += result_queue.get_nowait() + except queue.Empty: + fut: Future[list[WriteResult]] = Future() + fut.set_result(res) + return fut + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + metadata = dataclasses.replace(metadata, version=CURRENT_DCP_VERSION) + + storage_md = {} + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + + metadata.storage_meta = self.storage_meta() + tmp_filename = ( + f"__{self.rank}{_metadata_fn}.tmp" + if not self.use_collectives and self.rank is not None + else f"{_metadata_fn}.tmp" + ) + tmp_path = cast(Path, self.fs.concat_path(self.path, tmp_filename)) + with self.fs.create_stream(tmp_path, "wb") as metadata_file: + pickle.dump(metadata, metadata_file) + if self.sync_files: + try: + os.fsync(metadata_file.fileno()) + except (AttributeError, UnsupportedOperation): + os.sync() + + # delete in-case other checkpoints were present. + if not self.use_collectives and self.rank is not None: + metadata_path = self._get_metadata_path(self.rank) + else: + metadata_path = self._get_metadata_path() + + if self.fs.exists(metadata_path): + self.fs.rm_file(metadata_path) + + self.fs.rename(tmp_path, metadata_path) + + def storage_meta(self) -> Optional[StorageMeta]: + return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id) + + def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: + filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" + return cast(Path, self.fs.concat_path(self.path, filename)) + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to save the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class _StorageReaderTransforms: + """ + This is experimental, and will likely move elsewhere in the + future. It lives here to minimize changes while we are still + learning and gathering feedback. + """ + + def __init__(self, extension_registry: Optional[ExtensionRegistry] = None) -> None: + self.extension_registry = ( + ExtensionRegistry() if extension_registry is None else extension_registry + ) + + def transform_load_stream( + self, + read_item: ReadItem, + transform_descriptors: Sequence[str], + raw_stream: IO[bytes], + ) -> IO[bytes]: + extensions = self.extension_registry.from_descriptor_list(transform_descriptors) + transform_from = raw_stream + for ex in extensions: + if isinstance(ex, StreamTransformExtension): + transform_from = ex.transform_from(transform_from) + return transform_from + + +class FileSystemReader(StorageReader): + def __init__( + self, + path: Union[str, os.PathLike], + _extension_registry: Optional[ExtensionRegistry] = None, # EXPERIMENTAL + ) -> None: + super().__init__() + self.fs = FileSystem() + self.path = self.fs.init_path(path) + self.storage_data: dict[Any, Any] = {} + self.load_id = _generate_uuid() + self.transforms = _StorageReaderTransforms(_extension_registry) + self.rank = None + self.use_collectives = True + + def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]: + return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length)) + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + self.storage_data = {} + if checkpoint_id: + self.path = self.fs.init_path(checkpoint_id) + self.load_id = _generate_uuid() + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: dict[str, list[ReadItem]] = {} + for read_item in plan.items: + item_md: _StorageInfo = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + for relative_path, reqs in per_file.items(): + new_path = self.fs.concat_path(self.path, relative_path) + with self.fs.create_stream(new_path, "rb") as stream: + # TODO sort by offset and cache the reading + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(stream, item_md) + transform_from = self.transforms.transform_load_stream( + req, + # This field wasn't present in older + # implementations so provide a fallback. + item_md.transform_descriptors or (), + file_slice, + ) + + if req.type == LoadItemType.BYTE_IO: + read_bytes = io.BytesIO(transform_from.read(-1)) + read_bytes.seek(0) + planner.load_bytes(req, read_bytes) + else: + if transform_from.seekable(): + seekable = transform_from + else: + # torch.load requires a seekable input, so read the transform + # stream now and store the output if needed + seekable = io.BytesIO(transform_from.read(-1)) + seekable.seek(0) + + tensor = cast( + Tensor, + torch.load( + seekable, + map_location="cpu", + weights_only=True, + ), + ) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: + filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" + return cast(Path, self.fs.concat_path(self.path, filename)) + + # Implementing the abstract function in StorageReader + def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: + rank = kwargs.get("rank") + path = self._get_metadata_path(rank) + with self.fs.create_stream(path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id + + return metadata + + def set_up_storage_reader( + self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + self.storage_data = metadata.storage_data + self.rank = kwargs.get("rank") + self.use_collectives = kwargs.get("use_collectives", True) + if self.storage_data is None: + raise AssertionError("storage_data must not be None in metadata") + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]: + return plans + + @property + def checkpoint_id(self) -> Union[str, os.PathLike]: + """ + return the checkpoint_id that will be used to load the checkpoint. + """ + return self.path + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + return FileSystem.validate_checkpoint_id(checkpoint_id) + + +class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a global `.metadata` file with the serialized metadata if rank coordination is enabled. + a rank local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + thread_count: int = 1, + per_thread_copy_ahead: int = 10_000_000, + cache_staged_state_dict: bool = False, + overwrite: bool = True, + _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, + ) -> None: + """ + Initialize the writer pointing to `path`. + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + thread_count: Number of IO threads to use to write. Default to 1. + per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. + _extensions: Extensions to apply to output streams (EXPERIMENTAL) + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + _FileSystemWriter.__init__( + self, + path=path, + single_file_per_rank=single_file_per_rank, + sync_files=sync_files, + thread_count=thread_count, + per_thread_copy_ahead=per_thread_copy_ahead, + overwrite=overwrite, + _extensions=_extensions, + serialization_format=serialization_format, + ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Override of AsyncStager.stage""" + # in the async case, the state dict is already on CPU, so maintaining this + # buffer makes no sense + self.per_thread_copy_ahead = 0 + return super().stage(state_dict) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..912f983fe2a7ce9267ce74940d42f9bd2b3969ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/format_utils.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +import argparse +import os +from enum import Enum +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._nested_dict import flatten_state_dict +from torch.distributed.checkpoint.default_planner import ( + _EmptyStateDictLoadPlanner, + DefaultLoadPlanner, +) +from torch.distributed.checkpoint.metadata import ( + Metadata, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import _create_chunk_list +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torch.distributed.checkpoint.state_dict_saver import _save_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.futures import Future + + +__all__ = [ + "dcp_to_torch_save", + "torch_save_to_dcp", + "BroadcastingTorchSaveReader", + "DynamicMetaLoadPlanner", +] + + +class BroadcastingTorchSaveReader(StorageReader): + """ + StorageReader for reading a Torch Save file. This reader will read the entire checkpoint + on the coordinator rank, and then broadcast and shard each tensor to all ranks. + + . N.B. Intended to be used with DynamicMetaLoadPlanner + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def __init__( + self, + checkpoint_id: Optional[Union[str, os.PathLike]] = None, + coordinator_rank: int = 0, + ) -> None: + self.checkpoint_id = checkpoint_id + self.coordinator_rank = coordinator_rank + + # pyrefly: ignore [bad-override] + def read_metadata(self) -> Metadata: + """Extends the default StorageReader to support building the metadata file""" + # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from + # the disk + return Metadata(state_dict_metadata={}) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Reads torch save data on the coordinator rank, and broadcast afterwards + this incurrs a communication cost, but avoids having to load + the entire checkpoint on each rank, hopefully preventing OOM issues + """ + planner = cast(DefaultLoadPlanner, planner) + + # data is read in on the coordinator rank, and broadcast afterwards + # this incurs a communication cost, but it avoids having to load + # the entire checkpoint on each rank, hopefully preventing OOM issues + # TODO: read on each host, instead of only the coordinator + if self.is_coordinator: + if self.checkpoint_id is None: + raise AssertionError("checkpoint_id must be set before reading data") + torch_state_dict = torch.load( + self.checkpoint_id, map_location="cpu", weights_only=False + ) + if planner.flatten_state_dict: + torch_state_dict, _ = flatten_state_dict(torch_state_dict) + else: + torch_state_dict = None + + for req in plan.items: + if req.type == LoadItemType.BYTE_IO: + raise RuntimeError( + f"Non-tensor value identified at {req.storage_index.fqn}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + # Broadcast the tensor from the coordinator rank + if self.is_coordinator: + pg_device = dist.distributed_c10d._get_pg_default_device() + # pyrefly: ignore [unsupported-operation] + tensor = torch_state_dict[req.storage_index.fqn].to(pg_device) + else: + tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn]) + + dist.broadcast(tensor, src=self.coordinator_rank, async_op=False) + + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + if not target_tensor.size() == tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes, " + f"{target_tensor.size()} vs {tensor.size()}" + ) + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + fut: Future = Future() + fut.set_result(None) + return fut + + # pyrefly: ignore [bad-override] + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + """Implementation of the StorageReader method""" + self.is_coordinator = is_coordinator + if self.is_coordinator: + if not dist.get_rank() == self.coordinator_rank: + raise AssertionError( + f"Coordinator rank mismatch: expected {self.coordinator_rank}, " + f"got {dist.get_rank()}" + ) + + if self.checkpoint_id is None: + raise AssertionError( + "checkpoint_id must be set before setting up storage reader" + ) + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """Implementation of the StorageReader method""" + return plan + + def prepare_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """Implementation of the StorageReader method""" + return global_plan + + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """Implementation of the StorageReader method""" + self.checkpoint_id = checkpoint_id + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """Implementation of the StorageReader method""" + return os.path.isfile(checkpoint_id) + + +class DynamicMetaLoadPlanner(DefaultLoadPlanner): + """ + Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict, + avoiding the need to read metadata from disk. This is useful when reading formats which don't have a + metadata file, like Torch Save files. + + . N.B. Intended to be used with BroadcastingTorchSaveReader + + .. warning:: + Current implementation only supports loading Tensors. + + >>> # xdoctest: +SKIP("undefined vars") + >>> sd = {"mode": model} + >>> dcp.load( + >>> sd, + >>> storage_reader=BroadcastingTorchSaveReader(), + >>> planner=DynamicMetaLoadPlanner(), + >>> checkpoint_id="path_to_model.pt" + >>> ) + """ + + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict""" + super().set_up_planner(state_dict, metadata, is_coordinator) + + state_dict_metadata: dict[str, STORAGE_TYPES] = {} + for key, tensor in self.state_dict.items(): + if not torch.is_tensor(tensor): + raise RuntimeError( + f"Non-tensor value identified at {key}. " + f"At this time {type(self).__name__} only supports loading Tensors." + ) + + state_dict_metadata[key] = TensorStorageMetadata( + TensorProperties(dtype=tensor.dtype), + tensor.size(), + _create_chunk_list(tensor), + ) + self.metadata = Metadata(state_dict_metadata=state_dict_metadata) + + +def dcp_to_torch_save( + dcp_checkpoint_dir: Union[str, os.PathLike], + torch_save_path: Union[str, os.PathLike], +): + """ + Given a directory containing a DCP checkpoint, this function will convert it into a + Torch save file. + + Args: + dcp_checkpoint_dir: Directory containing the DCP checkpoint. + torch_save_path: Filename to store the converted Torch save file. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + sd: STATE_DICT_TYPE = {} + _load_state_dict( + sd, + storage_reader=FileSystemReader(dcp_checkpoint_dir), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + torch.save(sd, torch_save_path) + + +def torch_save_to_dcp( + torch_save_path: Union[str, os.PathLike], + dcp_checkpoint_dir: Union[str, os.PathLike], +): + """ + Given the location of a torch save file, converts it into a DCP checkpoint. + + Args: + torch_save_path: Filename of the Torch save file. + dcp_checkpoint_dir: Directory to store the DCP checkpoint. + + .. warning:: + To avoid OOM, it's recommended to only run this function on a single rank. + """ + + state_dict = torch.load(torch_save_path, weights_only=False) + # we don't need stateful behavior here because the expectation is anything loaded by + # torch.load would not contain stateful objects. + _save_state_dict( + state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True + ) + + +if __name__ == "__main__": + + class FormatMode(Enum): + TORCH_TO_DCP = "torch_to_dcp" + DCP_TO_TORCH = "dcp_to_torch" + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument( + "mode", + type=str, + help="Conversion mode", + choices=[m.value for m in FormatMode], + default=FormatMode.TORCH_TO_DCP, + ) + parser.add_argument("src", type=str, help="Path to the source model") + parser.add_argument("dst", type=str, help="Path to the destination model") + args = parser.parse_args() + + print( + f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'" + ) + checkpoint_missing_warning = ( + f"No checkpoint found at {args.src}. Skipping conversion." + ) + if args.mode == FormatMode.TORCH_TO_DCP.value: + if os.path.isfile(args.src): + torch_save_to_dcp(args.src, args.dst) + else: + print(checkpoint_missing_warning) + elif args.mode == FormatMode.DCP_TO_TORCH.value: + if os.path.isdir(args.src): + dcp_to_torch_save(args.src, args.dst) + else: + print(checkpoint_missing_warning) + else: + raise ValueError(f"Unknown conversion mode: {args.mode}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..52f9209da0ec58826cfa3c445e2b2070c5dee60f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/hf_storage.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +import dataclasses +import json +import logging +import queue +import threading +from typing import Any, Optional + +import torch +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._consolidate_hf_safetensors import ( + consolidate_safetensors_files, +) +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _HFStorageInfo, + _metadata_fn, + CUSTOM_METADATA_KEY, + SAVED_OFFSETS_KEY, + SHARDED_DIR_NAME, + SUFFIX, +) +from torch.distributed.checkpoint.filesystem import SerializationFormat +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + StorageMeta, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, +) +from torch.distributed.checkpoint.storage import WriteResult +from torch.futures import Future + + +logger: logging.Logger = logging.getLogger(__name__) + +__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] + + +class HuggingFaceStorageWriter(FileSystemWriter): + """ + A writer that writes to storage in the huggingface safetensors format. + """ + + def __init__( + self, + path: str, + fqn_to_index_mapping: Optional[dict[str, int]] = None, + thread_count: int = 1, + save_distributed: bool = False, + enable_consolidation: bool = False, + thread_count_consolidation: int = 1, + ) -> None: + """ + Initialize the huggingface writer pointing to path. + + Args: + path: directory where the checkpoint will be read from. + fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. + Indices are from 1 to N, where N is the number of files. If not provided, + the tensors will be written to a single file. If none, then all the tensors on the + same rank will be written to the same file. + thread_count: Number of threads to use to write distributed checkpoint. Default to 1. + save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. + Default is False which assumes rank-0 checkpointing of the full state_dict. + enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be + saved to path/sharded and the full tensors will be saved to path. Default to False. + thread_count_consolidation: Number of threads to use for parallel processing of saving data + to consolidated output files. Default to 1. + """ + + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + thread_count=thread_count, + ) + self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping + self.save_distributed: bool = save_distributed + self.enable_consolidation: bool = enable_consolidation + self.consolidated_output_path: Optional[str] = None + if self.enable_consolidation: + self.consolidated_output_path = str(self.path) + self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME) + self.thread_count_consolidation = thread_count_consolidation + + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + new_plans = [] + for i, plan in enumerate(plans, start=1): + storage_data: dict[str, Any] = {} + if self.fqn_to_index_mapping is not None: + storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping + if self.save_distributed: + storage_data["shard_index"] = i + + new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) + + return new_plans + + def write_data( + self, + plan: SavePlan, + planner: SavePlanner, + ) -> Future[list[WriteResult]]: + if len(plan.items) == 0: + fut: Future = Future() + fut.set_result([]) + return fut + + # storage_plan is a map from key to file index + storage_data: dict[str, Any] = plan.storage_data + storage_plan: Optional[dict[str, int]] = None + shard_index: Optional[int] = None + if "fqn_to_index_mapping" in storage_data: + storage_plan = storage_data["fqn_to_index_mapping"] + if "shard_index" in storage_data: + shard_index = storage_data["shard_index"] + + buckets = self._split_by_storage_plan(storage_plan, plan.items) + highest_index = max(storage_plan.values()) if storage_plan is not None else 1 + + file_queue: queue.Queue = queue.Queue() + for file_index, write_items in buckets.items(): + file_name = _gen_file_name(file_index, highest_index, shard_index) + file_queue.put( + (self.fs.concat_path(self.path, file_name), file_name, write_items) + ) + + return super()._write_data(planner, file_queue) + + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + if self.save_distributed and not self.enable_consolidation: + # if we are saving distributed, without consolidating, + # then we have no metadata to write because a metadata + # file with fqn to file mapping doesn't make sense + # in this case, because fqns will be in multiple files + logger.info("Not consolidating sharded checkpoint in finish step.") + return + if self.save_distributed: + fqn_to_index_mapping: dict[str, int] = ( + self.fqn_to_index_mapping + if self.fqn_to_index_mapping is not None + else dict.fromkeys(metadata.state_dict_metadata.keys(), 1) + ) + + return consolidate_safetensors_files( + input_dir=str(self.path), + output_dir=self.consolidated_output_path, # type: ignore[arg-type] + num_threads=self.thread_count_consolidation, + fqn_to_index_mapping=fqn_to_index_mapping, + ) + + # writing a model.index.safetensors.json file with fqn to file mapping + # for the rank-0 checkpointing case + metadata_to_write = {} + storage_md = {} + total_size = 0 + for wr_list in results: + storage_md.update( + {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list} + ) + total_size += sum([wr.storage_data.length for wr in wr_list]) + metadata_to_write["metadata"] = {"total_size": total_size} + metadata_to_write["weight_map"] = storage_md + + metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}") + with self.fs.create_stream(metadata_path, "w") as metadata_file: + json.dump(metadata_to_write, metadata_file, indent=2) + + def _split_by_storage_plan( + self, storage_plan: Optional[dict[str, int]], items: list[WriteItem] + ) -> dict[int, list[WriteItem]]: + # storage_plan is a map from key to index + if storage_plan is None: + return {1: items} + + buckets = {} + for item in items: + key = item.index.fqn + + idx = storage_plan[key] + if idx not in buckets: + buckets[idx] = [item] + else: + buckets[idx].append(item) + + return buckets + + @property + def metadata_path(self) -> str: + return _metadata_fn + + +class HuggingFaceStorageReader(FileSystemReader): + """ + A reader that reads a checkpoint in the huggingface safetensors format. + """ + + def __init__(self, path: str, thread_count: int = 1) -> None: + """ + Initialize the huggingface reader pointing to path. + + Args: + path: directory where the checkpoint will be read from. + thread_count: Number of threads to use to read distributed checkpoint. Default to 1. + """ + + super().__init__(path=path) + self.thread_count = thread_count + + def _process_read_request(self, f, req: ReadItem, planner: LoadPlanner) -> None: + """Helper function to process a single read request.""" + # Create slices for each dimension based on offsets and lengths + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + tensor = f.get_slice(req.storage_index.fqn)[slices] + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def _read_files_from_queue( + self, + file_queue: queue.Queue, + result_queue: queue.Queue, + planner: LoadPlanner, + ) -> None: + from safetensors import safe_open # type: ignore[import] + + try: + while True: + file_name, reqs = file_queue.get_nowait() + with safe_open(filename=file_name, framework="pt") as f: + for req in reqs: + self._process_read_request(f, req, planner) + result_queue.put(True) # Signal that this file has been processed + except queue.Empty: + pass + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import safe_open # type: ignore[import] + + per_file: dict[str, list[ReadItem]] = {} + + for read_item in plan.items: + item_md: _HFStorageInfo = self.storage_data[read_item.storage_index] + file_name = item_md.relative_path + per_file.setdefault(file_name, []).append(read_item) + + if self.thread_count <= 1 or len(per_file) <= 1: + for file_name, reqs in per_file.items(): + with safe_open(filename=file_name, framework="pt") as f: + for req in reqs: + self._process_read_request(f, req, planner) + else: + # Use parallel implementation with thread pool + file_queue: queue.Queue = queue.Queue() + result_queue: queue.Queue = queue.Queue() + + # Fill the queue with files to process + for file_name, reqs in per_file.items(): + file_queue.put((file_name, reqs)) + + # Create and start worker threads + threads = [] + num_threads = min(self.thread_count, len(per_file)) + for _ in range(num_threads): + t = threading.Thread( + target=self._read_files_from_queue, + args=(file_queue, result_queue, planner), + ) + t.start() + threads.append(t) + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check if all files were processed + processed_count = 0 + try: + while True: + result_queue.get_nowait() + processed_count += 1 + except queue.Empty: + pass + + if processed_count != len(per_file): + raise AssertionError( + f"Not all files were processed: {processed_count} out of {len(per_file)}" + ) + + fut: Future = Future() + fut.set_result(None) + return fut + + # pyrefly: ignore [bad-override] + def read_metadata(self) -> Metadata: + from safetensors import safe_open # type: ignore[import] + from safetensors.torch import _getdtype # type: ignore[import] + + state_dict_metadata: dict[str, TensorStorageMetadata] = {} + storage_data: dict[MetadataIndex, _HFStorageInfo] = {} + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(file) + + for safetensor_file in safetensors_files: + with safe_open(safetensor_file, framework="pt") as f: + keys = f.keys() + extra_metadata = f.metadata() + + dcp_sharding_info = None + if extra_metadata and extra_metadata.get(CUSTOM_METADATA_KEY): + dcp_sharding_info = json.loads( + extra_metadata.get(CUSTOM_METADATA_KEY) + ) + + for key in keys: + shape = f.get_slice(key).get_shape() + dtype = f.get_slice(key).get_dtype() + # construct state_dict_metadata + if dcp_sharding_info is not None: + offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] + else: + offset = [0] * len(shape) + + if key not in state_dict_metadata: + state_dict_metadata[key] = TensorStorageMetadata( + properties=TensorProperties(dtype=_getdtype(dtype)), + size=torch.Size( + [saved + offset for saved, offset in zip(shape, offset)] + ), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=torch.Size(shape), + ) + ], + ) + else: + state_dict_metadata[key].chunks.append( + ChunkStorageMetadata( + torch.Size(offset), sizes=torch.Size(shape) + ) + ) + size = list(state_dict_metadata[key].size) + for i in range(len(size)): + size[i] = max(size[i], shape[i] + offset[i]) + state_dict_metadata[key].size = torch.Size(size) + + # construct storage data + if dcp_sharding_info is not None: + metadata_index = MetadataIndex( + fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] + ) + else: + metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape)) + storage_data[metadata_index] = _HFStorageInfo( + relative_path=safetensor_file, + shape=torch.Size(shape), + dtype=_getdtype(dtype), + ) + + metadata = Metadata( + state_dict_metadata=state_dict_metadata, # type: ignore[arg-type] + storage_data=storage_data, + ) + + if getattr(metadata, "storage_meta", None) is None: + metadata.storage_meta = StorageMeta() + metadata.storage_meta.load_id = self.load_id # type: ignore[union-attr] + + return metadata diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..677cac0339cb9fab60c77f75da04bc7ef06504f3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +import functools +import logging +import time +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import ParamSpec +from uuid import uuid4 + +import torch.distributed.c10d_logger as c10d_logger +from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + + +logger = logging.getLogger() + + +__all__: list[str] = [] + +# pyrefly: ignore [unknown-name] +global _dcp_logger +_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]: + """ + Extracts log data from dcp method args + """ + msg_dict = {} + + # checkpoint ID can be passed in through the serializer or through the checkpoint id directly + storage_writer = kwargs.get("storage_writer") + storage_reader = kwargs.get("storage_reader") + planner = kwargs.get("planner") + + checkpoint_id = kwargs.get("checkpoint_id") + if not checkpoint_id and (serializer := storage_writer or storage_reader): + checkpoint_id = getattr(serializer, "checkpoint_id", None) + + msg_dict["checkpoint_id"] = ( + # pyrefly: ignore [unsupported-operation] + str(checkpoint_id) if checkpoint_id is not None else checkpoint_id + ) + + # Uniquely identify a _dcp_method_logger wrapped function call. + msg_dict["uuid"] = str(uuid4().int) + + if storage_writer: + msg_dict["storage_writer"] = storage_writer.__class__.__name__ + + if storage_reader: + msg_dict["storage_reader"] = storage_reader.__class__.__name__ + + if planner: + msg_dict["planner"] = planner.__class__.__name__ + + return msg_dict + + +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: + msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs) + msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs)) + + return msg_dict + + +def _dcp_method_logger( + log_exceptions: bool = False, **wrapper_kwargs: Any +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore + """This method decorator logs the start, end, and exception of wrapped events.""" + + def decorator(func: Callable[_P, _T]): + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + msg_dict = _get_msg_dict( + func.__name__, *args, **{**wrapper_kwargs, **kwargs} + ) + + # log start event + msg_dict["event"] = "start" + t0 = time.time_ns() + msg_dict["time"] = t0 + msg_dict["log_exceptions"] = log_exceptions + _dcp_logger.debug(msg_dict) + + # exceptions + try: + result = func(*args, **kwargs) + except BaseException as error: + if log_exceptions: + msg_dict["event"] = "exception" + msg_dict["error"] = f"{error}" + msg_dict["time"] = time.time_ns() + _dcp_logger.error(msg_dict) + raise + + # end event + msg_dict["event"] = "end" + t1 = time.time_ns() + msg_dict["time"] = time.time_ns() + msg_dict["times_spent"] = t1 - t0 + _dcp_logger.debug(msg_dict) + + return result + + return wrapper + + return decorator + + +def _init_logger(rank: int): + logger.setLevel(logging.INFO) + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + formatter = logging.Formatter( + f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch.setFormatter(formatter) + logger.addHandler(ch) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..99c3ee4156ce340e37a2723106df5ea64b19170d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logging_handlers.py @@ -0,0 +1,14 @@ +import logging + +from torch.distributed.logging_handlers import _log_handlers + + +__all__: list[str] = [] + +DCP_LOGGER_NAME = "dcp_logger" + +_log_handlers.update( + { + DCP_LOGGER_NAME: logging.NullHandler(), + } +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..36864b6bf3ad60778ad008fcbb4c10002933c4c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/metadata.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +import os +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.stateful import StatefulT + + +__all__ = [ + "ChunkStorageMetadata", + "TensorStorageMetadata", + "BytesStorageMetadata", + "Metadata", + "MetadataIndex", + "TensorProperties", + "StorageMeta", +] + + +@dataclass +class ChunkStorageMetadata: + """ + Each chunk is expected to have the same properties of the TensorStorageMetadata + that includes it. + """ + + offsets: torch.Size + sizes: torch.Size + + +class _MEM_FORMAT_ENCODING(Enum): + """Describe the memory format of a tensor.""" + + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default_factory=torch.get_default_dtype) + # This field is deprecated. + layout: torch.layout = field(default=torch.strided) + # This field is deprecated. + requires_grad: bool = False + # This field is deprecated. + memory_format: torch.memory_format = field(default=torch.contiguous_format) + # This field is deprecated. + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class TensorStorageMetadata: + properties: TensorProperties + size: torch.Size + chunks: list[ChunkStorageMetadata] + + +@dataclass +class BytesStorageMetadata: + pass + + +STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] +STATE_DICT_TYPE = dict[str, Union[StatefulT, Any]] + + +@dataclass +class StorageMeta: + checkpoint_id: Union[str, os.PathLike, None] = None + save_id: Optional[str] = None + load_id: Optional[str] = None + modules: list[str] = field(default_factory=list) + + +@dataclass +class Metadata: + """This class represents the metadata of the checkpoint.""" + + # Keys are the same from the `state_dict` used. + state_dict_metadata: dict[str, STORAGE_TYPES] + # It is the responsibility of the planner and storage plugins to ensure + # backward compatibility of the planner_data and storage_data. DCP will + # also ensure the backward compatibility of the metadata in this file and + # the metadata of the built-in planner and storage plugins. + planner_data: Any = None + storage_data: Any = None + storage_meta: Optional[StorageMeta] = None + version: Optional[str] = None + + +@dataclass(frozen=True) +class MetadataIndex: + """This class represents a lookup key for items in a state dict or Metadata.""" + + fqn: str + """Fully Qualified Name of the object""" + + offset: Optional[torch.Size] = None + """If the object is a tensor, offset into the tensor we're looking for""" + + index: Optional[int] = field(hash=False, compare=False, default=None) + """ + Index hint when searching for tensor chunk to speedup lookups (optional) + + A common representation of a sharded tensor is as a list of chunks so to + find the index in such a list you need to linear search it. + + When constructing an instance of MetadataIndex that points to that list, + one can provide the index as a hint and it will be probed first before + the linear search and thus making it significantly faster. + """ + + def __init__( + self, + fqn: str, + offset: Optional[Sequence[int]] = None, + index: Optional[int] = None, + ): + # We must use object.__setattr__ due to frozen=True + object.__setattr__(self, "fqn", fqn) + object.__setattr__(self, "index", index) + if offset is not None: + object.__setattr__(self, "offset", torch.Size(offset)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..343497da0aa21f35a081a7ca9063d4dcbbf41ccc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/optimizer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import dataclasses +from collections.abc import Sequence +from typing import cast, Optional, Union + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.metadata import ( + TensorProperties as ShardTensorProperties, +) +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from torch.distributed.checkpoint._nested_dict import unflatten_state_dict +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner +from torch.distributed.checkpoint.planner_helpers import ( + _create_read_items, + create_read_items_for_chunk_list, +) + +# pyrefly: ignore [deprecated] +from torch.distributed.checkpoint.state_dict_loader import load_state_dict +from torch.distributed.checkpoint.storage import StorageReader +from torch.distributed.checkpoint.utils import ( + _element_wise_add, + _element_wise_sub, + _normalize_device_info, +) +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DTensor + + +STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]] + + +# TODO: Update docstrings for optimizer.py +__all__ = [ + "load_sharded_optimizer_state_dict", +] + + +def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: + if device_type == "cpu": + return "cpu" + device_module = _get_device_module(device_type) + if device_module.is_available(): + return _normalize_device_info( + device_type, global_rank % device_module.device_count() + ) + return "cpu" + + +def _create_colwise_spec( + pg: Optional[dist.ProcessGroup] = None, +) -> ChunkShardingSpec: + pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type + if pg is None: + placements = [ + f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" + for idx in range(dist.get_world_size()) + ] + else: + placements = [ + f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" + for idx in range(pg.size()) + ] + return ChunkShardingSpec( + dim=0, + placements=cast(list[Union[_remote_device, str]], placements), + ) + + +def _is_nested_tensor(val: torch.Tensor) -> bool: + if type(val) is ShardedTensor: + if len(val.local_shards()) == 0: + return False + if type(val.local_shards()[0].tensor) is ShardedTensor: + return True + if type(val.local_shards()[0].tensor) is DTensor: + raise ValueError("Cannot handle DTensor nested inside ShardedTensor") + elif type(val) is DTensor and ( + type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor + ): + raise ValueError("Cannot handle nested DTensor") + return False + + +def _alloc_tensor( + props: TensorProperties, size: Sequence[int], device_type: str = "cuda" +) -> torch.Tensor: + if device_type == "cpu": + device = cast(torch.device, _get_device_module(device_type).current_device()) + else: + device = torch.device( + device_type, _get_device_module(device_type).current_device() + ) + + return torch.empty( + size=size, + dtype=props.dtype, + layout=props.layout, + requires_grad=props.requires_grad, + pin_memory=props.pin_memory, + device=device, + ) + + +def _get_state_dict_2d_layout( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]: + """ + Load the right TP slice of the optimizer state. + + This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata. + We take advantage of the model state_dict producing a sliced ST to figure out what we need to load. + This is pretty fragile and it might be easier for FSDP to compute this info for us. + Returns a dictionary where keys are the same of the state_dict and the value is a tuple of + (offset, size) for the current rank TP slice. + N.B. The state_dict *MUST* come from FSDP.sharded_state_dict. + """ + specs: STATE_DICT_2D_LAYOUT = {} + dp_pg: Optional[dist.ProcessGroup] = None + for key, value in state_dict.items(): + specs[key] = (None, value.size()) + if _is_nested_tensor(value): + if not len(value.local_shards()) == 1: + raise AssertionError("Cannot handle ST with multiple shards") + if not isinstance(value, ShardedTensor): + raise AssertionError("Can only handle nested ShardedTensor") + shard = value.local_shards()[0] + specs[key] = ( + shard.metadata.shard_offsets, + shard.metadata.shard_sizes, + ) + dp_pg = shard.tensor._process_group # type: ignore[attr-defined] + + return ( + specs, + dp_pg, + ) + + +class _ReaderWithOffset(DefaultLoadPlanner): + translation: dict[MetadataIndex, MetadataIndex] + state_dict: STATE_DICT_TYPE + # pyrefly: ignore [bad-override] + metadata: Metadata + + def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None: + super().__init__() + self.fqn_to_offset = fqn_to_offset + self.metadata = Metadata({}) + self.state_dict = {} + self.translation = {} + + def create_local_plan(self) -> LoadPlan: + requests = [] + self.translation = {} + for fqn, obj in self.state_dict.items(): + md = self.metadata.state_dict_metadata[fqn] + if not isinstance(obj, ShardedTensor): + requests += _create_read_items(fqn, md, obj) + continue + + if fqn not in self.fqn_to_offset: + requests += _create_read_items(fqn, md, obj) + continue + + offset = self.fqn_to_offset[fqn] + + if not len(obj.local_shards()) == 1: + raise AssertionError("Expected exactly one local shard") + original_shard = obj.local_shards()[0] + local_chunks = [ + ChunkStorageMetadata( + offsets=torch.Size( + _element_wise_add(original_shard.metadata.shard_offsets, offset) + ), + sizes=torch.Size(original_shard.metadata.shard_sizes), + ) + ] + + reqs = create_read_items_for_chunk_list( + fqn, cast(TensorStorageMetadata, md), local_chunks + ) + # TODO: The ReadItems will have a displaced MetadataIndex, fix it. + # TODO: we should change _create_sharded_read_items to have more ergonomic API + for ri in reqs: + if ri.dest_index.offset is None: + raise AssertionError("dest_index.offset must not be None") + original_offset = _element_wise_sub(ri.dest_index.offset, offset) + original_index = dataclasses.replace( + ri.dest_index, offset=torch.Size(original_offset) + ) + self.translation[ri.dest_index] = original_index + + requests += reqs + return LoadPlan(requests) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + return super().lookup_tensor(self.translation.get(index, index)) + + +def load_sharded_optimizer_state_dict( + model_state_dict: STATE_DICT_TYPE, + optimizer_key: str, + storage_reader: StorageReader, + planner: Optional[LoadPlanner] = None, +) -> STATE_DICT_TYPE: + """ + Load a state_dict in conjunction with FSDP sharded optimizer state. + + This is the current recommended way to checkpoint FSDP. + >>> # xdoctest: +SKIP + >>> import torch.distributed.checkpoint as dist_cp + >>> # Save + >>> model: torch.nn.Model + >>> optim_params = model.parameters() + >>> optim = torch.optim.SGD(optim_params, lr=0.01) + >>> # Save + >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + >>> state_dict = { + >>> "optimizer": FSDP.optim_state_dict(model, optim), + >>> "model": model.state_dict() + >>> } + >>> dist_cp.save_state_dict( + >>> state_dict=optim_state, + >>> storage_writer=dist_cp.FileSystemWriter("checkpoint"), + >>> planner=dist_cp.DefaultSavePlanner(), + >>> ) + >>> + >>> # Load + >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + >>> model_state_dict = model_tp.state_dict() + >>> checkpoint = { + >>> "model": model_state_dict + >>> } + >>> dist_cp.load_state_dict( + >>> state_dict=checkpoint, + >>> storage_reader=dist_cp.FileSystemReader(checkpoint_file), + >>> planner=dist_cp.DefaultLoadPlanner(), + >>> ) + >>> model.load_state_dict(checkpoint["model_state"]) + >>> + >>> optim_state = dist_cp.load_sharded_optimizer_state_dict( + >>> model_state_dict, + >>> optimizer_key="optimizer", + >>> storage_reader=dist_cp.FileSystemReader("checkpoint"), + >>> ) + >>> + >>> flattened_osd = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state["optimizer"] + >>> ) + >>> + >>> optim.load_state_dict(flattened_osd) + """ + metadata = storage_reader.read_metadata() + + layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) + dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type + device_module = _get_device_module(dp_pg_device_type) + + if dp_pg is None: + placements = [] + for i in range(dist.get_world_size()): + device_info = _normalize_device_info( + dp_pg_device_type, i % device_module.device_count() + ) + placements.append(f"rank:{i}/{device_info}") + sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] + else: + sharding_spec = _create_colwise_spec(dp_pg) + + # Create a state_dict for optimizer state + state_dict: STATE_DICT_TYPE = {} + + fqn_to_offset: dict[str, Sequence[int]] = {} + for key, value in metadata.state_dict_metadata.items(): + key_path = metadata.planner_data[key] + if key_path[0] != optimizer_key: + continue + + if isinstance(value, BytesStorageMetadata): + state_dict[key] = "" + continue + + # value: TensorStorageMetadata + if value.size.numel() == 1: + state_dict[key] = _alloc_tensor( + value.properties, value.size, dp_pg_device_type + ) + elif dp_pg is None: + state_dict[key] = _create_chunk_sharded_tensor( + _alloc_tensor(value.properties, value.size, dp_pg_device_type), + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_devices_per_node=device_module.device_count(), + pg=_get_default_group(), + ) + else: + spec_key = key_path[2] + alloc_size = layout_specs.get(spec_key, (None, value.size))[1] + + properties = ShardTensorProperties( + dtype=value.properties.dtype, + layout=value.properties.layout, + requires_grad=value.properties.requires_grad, + memory_format=value.properties.memory_format, + pin_memory=value.properties.pin_memory, + ) + + st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties) + local_shards = [] + current_rank = dist.get_rank(dp_pg) + for shard_md in st_md.shards_metadata: + if cast(_remote_device, shard_md.placement).rank() != current_rank: + continue + local_shards.append( + Shard( + tensor=_alloc_tensor( + value.properties, shard_md.shard_sizes, dp_pg_device_type + ), + metadata=shard_md, + ) + ) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, st_md, process_group=dp_pg + ) + + if spec_key in layout_specs and layout_specs[spec_key][0] is not None: + fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0]) + + state_dict[key] = st + + # Whether we unflatten before or after doesn't matter + load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + # FIXME the type of planner is wrong in load_state_dict + planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner, + ) + + state_dict = unflatten_state_dict(state_dict, metadata.planner_data) + + return state_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..8c97dc0379b109dd3a9706176390720a88128851 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner.py @@ -0,0 +1,450 @@ +import abc +import io +import operator +from dataclasses import dataclass +from enum import auto, Enum +from functools import reduce +from typing import Any, Optional, Union + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + Metadata, + MetadataIndex, + STATE_DICT_TYPE, + StorageMeta, + TensorProperties, +) + + +__all__ = [ + "WriteItemType", + "LoadItemType", + "BytesIOWriteData", + "TensorWriteData", + "WriteItem", + "ReadItem", + "SavePlan", + "LoadPlan", + "SavePlanner", + "LoadPlanner", +] + + +class WriteItemType(Enum): + TENSOR = auto() + SHARD = auto() + BYTE_IO = auto() + + +class LoadItemType(Enum): + TENSOR = auto() + BYTE_IO = auto() + + +@dataclass(frozen=True) +class BytesIOWriteData: + nbytes: int + + +@dataclass(frozen=True) +class TensorWriteData: + chunk: ChunkStorageMetadata + properties: TensorProperties + size: torch.Size + + +@dataclass(frozen=True) +class WriteItem: + """Dataclass which holds information about what needs to be written to storage.""" + + index: MetadataIndex + type: WriteItemType + + # Size of bytesIO data to be written. + bytes_io_data: Optional[BytesIOWriteData] = None + + # Value present if it's a tensor write + tensor_data: Optional[TensorWriteData] = None + + def tensor_storage_size(self) -> Optional[int]: + """ + Calculates the storage size of the underlying tensor, or None if this is not a tensor write. + + Returns: + Optional[int] storage size, in bytes of underlying tensor if any. + """ + if self.tensor_data is None: + return None + + numels = reduce(operator.mul, self.tensor_data.size, 1) + dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype) + return numels * dtype_size + + +@dataclass(frozen=True) +class ReadItem: + # Read Item + type: LoadItemType + + # Index into the state_dict + dest_index: MetadataIndex + # Offsets into destination tensor + dest_offsets: torch.Size + + # Index into the checkpoint + storage_index: MetadataIndex + # Offset into the checkpoint data + storage_offsets: torch.Size + + # Size of the hypercube to copy + lengths: torch.Size + + +@dataclass(frozen=True) +class SavePlan: + items: list[WriteItem] + storage_data: Any = None + planner_data: Any = None + # This is used to indicate that the ranks should + # use the cached plans to write data instead. + usable: bool = True + + +@dataclass +class LoadPlan: + items: list[ReadItem] + storage_data: Any = None + planner_data: Any = None + + +class SavePlanner(abc.ABC): + """ + Abstract class defining the protocol used by save_state_dict to plan the save process. + + SavePlanners are stateful objects that can be used to customize the whole save process. + + SavePlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during save_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of a checkpoint save. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `SavePlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the SavePlan from all ranks and make any global decision. + + 4) finish_plan - called on all ranks. + This gives each rank a chance to adjust to global planning decisions. + + 5) resolve_data - called multiple times on each rank + Lookups a value on the `state_dict` for the storage layer to write. + + Users are recommended to extend DefaultSavePlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are 3 usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the save process as it + doesn't requite understanding the intrincacies of how SavePlan works: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultSavePlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> storage_meta: Optional[StorageMeta], + >>> is_coordinator: bool, + >>> ) -> None: + >>> # prefix all keys with `foo_`` + >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) + + Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted + + >>> # xdoctest: +SKIP("undefined vars") + >>> class FP16Planner(DefaultSavePlanner): + >>> def create_local_plan(self): + >>> plan = super().create_local_plan() + >>> for p in plan: + >>> if p.tensor_data is not None: + >>> p.tensor_data.properties.dtype = torch.float16 + >>> return plan + >>> + >>> def resolve_data(self, write_item): + >>> item = super().resolve_data(write_item) + >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16) + + Using the global planning step to make central decisions that can't be made individually by each rank + + >>> # xdoctest: +SKIP("undefined vars") + >>> from itertools import zip_longest + >>> from dataclasses import replace + >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): + >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 + >>> # This sample doesn't handle ShardedTensors + >>> def create_global_plan(self, all_plans): + >>> iters = [iter(all_plans[0].items)] * len(all_plans) + >>> items_per_rank = [ + >>> [item for item in items if item is not None] + >>> for items in zip(*zip_longest(*iters), strict=True) + >>> ] + >>> all_plans = [ + >>> replace(plan, items=items) + >>> for plan, items in zip(all_plans, items_per_rank, strict=True) + >>> ] + >>> return super().create_global_plan(all_plans) + + Finally, some planners need to save additional metadata in the checkpoint, this is + accomplished by having each rank contribute their data items in the local plan and + the global planner aggregate them: + + >>> # xdoctest: +SKIP("undefined vars") + >>> class SaveExtraDataPlanner(DefaultSavePlanner): + >>> def create_local_plan(self) -> SavePlan: + >>> plan = super().create_local_plan() + >>> return replace(plan, planner_data="per-rank-data") + >>> + >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + >>> global_plan, metadata = super().create_global_plan(all_plans) + >>> merged_data = [p.planner_data for p in global_plan] + >>> metadata = replace(metadata, planner_data=merged_data) + >>> return global_plan, metadata + """ + + # Save plan for the current rank as computed by `create_local_plan` API + # Cached on the local rank. + _cached_save_plan: dict[str, SavePlan] = {} + # Final save plan for the current rank. + # This is created by merging the plan created by `create_local_plan` API + # and the result of `create_global_plan` for the given rank. + # This is the final plan computed by the `finish_plan` API that gets + # sent to the `write_data`. + # Cached on the local rank. + _cached_final_save_plan: dict[str, SavePlan] = {} + # Collection of all the local plans from all the ranks. + # This is the input to the `create_global_plan` API. + # Cached on the coordinator rank. + _cached_all_plans: dict[str, list[SavePlan]] = {} + # Global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_global_plan: dict[str, list[SavePlan]] = {} + # Metadata for the global checkpoint plan as computed by `create_global_plan` API. + # Cached on the coordinator rank. + _cached_metadata: dict[str, Metadata] = {} + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + storage_meta: Optional[StorageMeta] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this planner to save ``state_dict``. + + Implementations should save those values as they won't be provided lated in the save process. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_local_plan(self) -> SavePlan: + """ + Compute the save plan for the current rank. + + This will be aggregated and passed to create_global_plan. + Planner specific data can be passed through SavePlan::planner_data. + + This is called on all ranks. + """ + + @abc.abstractmethod + def create_global_plan( + self, all_plans: list[SavePlan] + ) -> tuple[list[SavePlan], Metadata]: + """ + Compute the global checkpoint plan and return the local plan of each rank. + + This is called on the coordinator rank only. + """ + + @abc.abstractmethod + def finish_plan(self, new_plan: SavePlan) -> SavePlan: + """ + Merge the plan created by `create_local_plan` and the result of `create_global_plan`. + + This is called on all ranks. + """ + + @abc.abstractmethod + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + """ + Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety. + + Lookup the object associated with ``write_item`` in ``state_dict`` and apply any + transformation (such as serialization) prior to the storage layer consuming it. + + Called on each rank multiple times, at least once per WriteItem in the final SavePlan. + + This method should be idempotent and thread-save. StorageWriter implementations + are free to call it as frequently as they need. + + Any transformation that allocates memory should be lazily done when his method + is called in order to reduce peak memory required by checkpointing. + + When returning tensors, they can be on any device or format, they can be views too. + It's the storage layer responsibility to figure out how to save them. + """ + + +class LoadPlanner: + """ + Abstract class defining the protocol used by load_state_dict to plan the load process. + + LoadPlanner are stateful objects that can be used to customize the whole load process. + + LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it + will be visible to the whole process. + + A planner subclass can expect the following sequence of calls during load_state_dict: + + 1) set_up_planner - called on all ranks. + Signals the start of loading a checkpoint. + + 2) create_local_plan - called on all ranks. + Process the state_dict and produces a `LoadPlan` that will be sent for global planning. + + 3) create_global_plan - called on the coordinator rank only. + Takes the LoadPlan from all ranks and make any global decision. + + 4) load_bytes - called multiple times on each rank + This is called once per non-tensor value in state_dict. + + 5) resolve_tensor and commit_tensor - called multiple times on each rank + They are called in pair for each Tensor value in state_dict. + + Users are recommended to extend DefaultLoadPlanner instead of this interface directly as + most changes can be expressed by changes in a single method. + + There are two usual patterns of extension: + + Rewriting state_dict. This is the simplest way to extend the load process as it + doesn't requite understanding the intrincacies of how LoadPlan works. We need + to keep a reference to the original state_dict as load happens in place so + we need to be able to perform it in place + + >>> # xdoctest: +SKIP("undefined vars") + >>> class RenamePlanner(DefaultLoadPlanner): + >>> def set_up_planner( + >>> self, + >>> state_dict: STATE_DICT_TYPE, + >>> metadata: Metadata, + >>> is_coordinator: bool, + >>> ) -> None: + >>> self.original_state_dict = state_dict + >>> state_dict = {"foo_" + k: v for k, v in state_dict.items()} + >>> + >>> if self.flatten_sharded_tensors: + >>> state_dict = _flatten_sharded_tensors(state_dict) + >>> + >>> if self.flatten_state_dict: + >>> state_dict, self.mappings = flatten_state_dict(state_dict) + >>> + >>> self.state_dict = state_dict + >>> self.metadata = metadata + >>> self.is_coordinator = is_coordinator + >>> + >>> def load_bytes(self, read_item, value): + >>> # Remove the "foo_" prefix + >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False) + + + Modifying resolve_tensor and commit_tensor to handle load time transformation. + + >>> # xdoctest: +SKIP("undefined vars") + >>> class MetaModelMaterialize(DefaultSavePlanner): + >>> def resolve_tensor(self, read_item): + >>> tensor = super().resolve_tensor(read_item) + >>> return torch.empty_like(tensor, device="cpu") + >>> + >>> def commit_tensor(self, read_item, tensor): + >>> self.state_dict[read_item.dest_index.fqn] = tensor + """ + + @abc.abstractmethod + def set_up_planner( + self, + state_dict: STATE_DICT_TYPE, + metadata: Optional[Metadata] = None, + is_coordinator: bool = False, + ) -> None: + """ + Initialize this instance to load data into ``state_dict``. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_local_plan(self) -> LoadPlan: + """ + Create a LoadPlan based on state_dict and metadata provided by set_up_planner. + + . N.B. This is called on every rank. + """ + + @abc.abstractmethod + def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]: + """ + Compute the global load plan and return plans for each rank. + + . N.B. This is called on the coordinator rank only + """ + + @abc.abstractmethod + def finish_plan(self, central_plan: LoadPlan) -> LoadPlan: + """Accept the plan from coordinator and return final LoadPlan.""" + + @abc.abstractmethod + def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: + """ + Load the item described by ``read_item``and ``value``. + + This method is expected to modify in-place the underlying state_dict. + + The contents of ``value`` are defined by the SavePlanner used to produce + the checkpoint being loaded. + """ + + def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO: + """ + Return the BytesIO to be used by the StorageReader to load `read_item`. + + The BytesIO should alias with one on the underlying state_dict as StorageReader will replace its contents. + """ + raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented") + + @abc.abstractmethod + def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor: + """ + Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`. + + The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents. + If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data + back to the one in state_dict. + """ + + @abc.abstractmethod + def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: + """ + Call once the StorageReader finished loading data into ``tensor``. + + The provided tensor is the same one returned by the call to ``resolve_tensor``. + This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to + copying it back to the one in the state_dict. + + The contents of tensor will follow its device synchronization model. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7af7d7a821b541cf66044a28d828d863624da2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/planner_helpers.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs +import io +from collections.abc import Callable +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + +from .metadata import ( + BytesStorageMetadata, + ChunkStorageMetadata, + MetadataIndex, + STATE_DICT_TYPE, + STORAGE_TYPES, + TensorProperties, + TensorStorageMetadata, +) +from .planner import ( + LoadItemType, + ReadItem, + SavePlan, + TensorWriteData, + WriteItem, + WriteItemType, +) +from .resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + + +__all__: list[str] = ["create_read_items_for_chunk_list"] + + +def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool: + """ + Compare the two Save plans and return True if they are equal. + + Args: + plan (SavePlan): First SavePlan to compare. + other_plan (SavePlan): Second SavePlan to compare. + + Returns: + True if the two plans are equal, False otherwise. + """ + if plan.usable != other_plan.usable: + return False + + # Both the plans should have the same number of items + if len(plan.items) != len(other_plan.items): + return False + + # Both the plans should have the same write items. + for plan_item, other_plan_item in zip(plan.items, other_plan.items): + # Write item type should be same + if plan_item.type != other_plan_item.type: + return False + + plan_metadata_index = plan_item.index + other_plan_metadata_index = other_plan_item.index + + # Write item metadata_index should be same + if ( + plan_metadata_index.fqn != other_plan_metadata_index.fqn + or plan_metadata_index.offset != other_plan_metadata_index.offset + or plan_metadata_index.index != other_plan_metadata_index.index + ): + return False + + # Write item tensor_data should be present in both the write items plans, if it exists in either of them. + tensor_data = plan_item.tensor_data + other_tensor_data = other_plan_item.tensor_data + if (tensor_data and not other_tensor_data) or ( + not tensor_data and other_tensor_data + ): + return False + + if tensor_data and other_tensor_data: + # Write item tensor_data size should be same + if tensor_data.size != other_tensor_data.size: + return False + + # Write item tensor_data chunk should be present in both the write items, if it exists in either of them. + chunk = tensor_data.chunk + other_chunk = other_tensor_data.chunk + if (chunk and not other_chunk) or (not chunk and other_chunk): + return False + + # Write item tensor_data chunk offsets and sizes should be same + if chunk and other_chunk: + if ( + chunk.offsets != other_chunk.offsets + or chunk.sizes != other_chunk.sizes + ): + return False + + return True + + +def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool: + """ + Check if any delta plan is usable, indicating the plan has changed. + + Args: + delta_plans (List[SavePlan]): A list of delta plans to check. + Returns: + True if any delta plan is usable, False otherwise. + """ + return any(delta_plan and delta_plan.usable for delta_plan in delta_plans) + + +def _merge_delta_local_plans( + cached_plans: list[SavePlan], + delta_plans: list[SavePlan], +) -> list[SavePlan]: + """ + Merge a list of delta plans into a single plan. + + Args: + cached_plans (List[SavePlan]): A list of cached plans. + delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans + + Returns: + A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan. + """ + merged_plans = [] + + for cached_plan, delta_plan in zip(cached_plans, delta_plans): + if delta_plan and not delta_plan.usable: + merged_plans.append(cached_plan) + else: + merged_plans.append(delta_plan) + + return merged_plans + + +def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size() + ) + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size(shard_md.shard_offsets), + sizes=torch.Size(shard_md.shard_sizes), + ) + + +def _sharded_tensor_metadata( + sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> TensorWriteData: + shard_properties = sharded_tensor.metadata().tensor_properties + + properties = TensorProperties( + dtype=shard_properties.dtype, + layout=shard_properties.layout, + requires_grad=shard_properties.requires_grad, + memory_format=shard_properties.memory_format, + pin_memory=shard_properties.pin_memory, + ) + + return TensorWriteData( + chunk=_chunk_for_shard(shard_md), + properties=properties, + size=sharded_tensor.metadata().size, + ) + + +def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ), + properties=TensorProperties.create_from_tensor(tensor.to_local()), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_shard( + fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> WriteItem: + offsets = torch.Size(shard_md.shard_offsets) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), + ) + + +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.TENSOR, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_read_item_for_byteio( + dest_index, dest_offset, storage_index, storage_offset, length +): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=torch.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=torch.Size((storage_offset,)), + lengths=torch.Size((length,)), + ) + + +def _create_read_item_for_tensor( + dest_index, dest_offsets, storage_index, storage_offsets, lengths +): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: list[ChunkStorageMetadata], +) -> list[ReadItem]: + """ + Create a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + _dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor( + saved_shard=storage_md, current_shard=shard + ): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: + requests = [] + for fqn, obj in state_dict.items(): + if isinstance(obj, DTensor): + requests.append(_create_write_items_for_dtensor(fqn, obj)) + elif isinstance(obj, ShardedTensor): + requests.extend( + _create_write_item_for_shard(fqn, obj, shard_md) + for shard_md in obj.metadata().shards_metadata + ) + elif isinstance(obj, torch.Tensor): + requests.append(_create_write_item_for_tensor(fqn, obj)) + else: + requests.append(_create_write_item_for_bytesio(fqn, obj)) + return SavePlan(requests) + + +def _create_write_items(fqn: str, object: Any) -> list[WriteItem]: + if hasattr(object, "__create_write_items__"): + # DTensor implements _Checkpointable + return object.__create_write_items__(fqn, object) + elif isinstance(object, ShardedTensor): + return [ + _create_write_item_for_shard(fqn, object, shard.metadata) + for shard in object.local_shards() + ] + elif isinstance(object, torch.Tensor): + return [_create_write_item_for_tensor(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: + sizes, offsets = compute_local_shape_and_global_offset( + tensor.shape, tensor.device_mesh, tensor.placements + ) + sizes, offsets = torch.Size(sizes), torch.Size(offsets) + return ChunkStorageMetadata( + offsets=offsets, + sizes=sizes, + ) + + +def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]: + if hasattr(tensor, "__create_chunk_list__"): + # DTensor implements _Checkpointable + local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(tensor, ShardedTensor): + local_chunks = [ + _chunk_for_shard(shard.metadata) for shard in tensor.local_shards() + ] + elif isinstance(tensor, torch.Tensor): + local_chunks = [_create_chunk_from_tensor(tensor)] + else: + raise ValueError( + "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] " + f",but got {type(tensor)}" + ) + + return local_chunks + + +def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + try: + local_chunks = _create_chunk_list(obj) + except ValueError as ex: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + + f"expected BytesStorageMetadata but found {type(md)}", + ) from ex + + return create_read_items_for_chunk_list(fqn, md, local_chunks) + else: + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _init_state_dict(state_dict: dict[str, Any]) -> Any: + """ + Initializes meta tensor if the meta tensor is DTensor or torch.Tensor. + """ + + def dtensor_func(value: DTensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + new_local_tensor = torch.empty_like(value.to_local(), device=device) + # We need to pass shape and stride explicitly, since DTensor might be + # sharded unevenly. + dtensor = DTensor.from_local( + new_local_tensor, + device_mesh=value.device_mesh, + placements=value.placements, + shape=value.size(), + stride=value.stride(), + ) + return dtensor + else: + return value + + def sharded_tensor_func(value: Any): + device = getattr(value, "device", None) + if device == torch.device("meta"): + raise RuntimeError( + f"Found unsupported type {type(value)} for meta device loading." + ) + else: + return value + + def tensor_func(value: torch.Tensor): + device = getattr(value, "device", None) + if device == torch.device("meta"): + device_type = dist.distributed_c10d._get_pg_default_device().type + device = cast( + torch.device, _get_device_module(device_type).current_device() + ) + tensor = torch.empty_like(value, device=device) + return tensor + else: + return value + + _iterate_state_dict( + state_dict, + dtensor_func, + sharded_tensor_func, + tensor_func, + ) + + +def _iterate_state_dict( + iter_object: Any, + dtensor_func: Callable, + sharded_tensor_func: Callable, + tensor_func: Callable, +): + """ + Iterate through the state dict, applying the given functions to each tensor type + and update the state dict in place. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + + # TODO: let state_dict_util._iterate_state_dict() to support in place option + so we don't need to have two versions of _iterate_state_dict. + """ + + if isinstance(iter_object, DTensor): + return dtensor_func(iter_object) + elif isinstance(iter_object, ShardedTensor): + return sharded_tensor_func(iter_object) + elif isinstance(iter_object, torch.Tensor): + return tensor_func(iter_object) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + return iter_object + elif isinstance(iter_object, dict): + for key, value in iter_object.items(): + iter_object[key] = _iterate_state_dict( + value, dtensor_func, sharded_tensor_func, tensor_func + ) + return iter_object + elif isinstance(iter_object, (list, tuple)): + ret = [ + _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func) + for v in iter_object + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) # type: ignore[assignment] + return ret diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..464052d99062a9b7b4e4b156cbe7a25d0fedc017 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/quantized_hf_storage.py @@ -0,0 +1,506 @@ +# mypy: allow-untyped-defs +import json +import logging +import math +from pathlib import Path +from typing import Any + +import torch +from torch.distributed.checkpoint._hf_utils import _metadata_fn +from torch.distributed.checkpoint.metadata import TensorStorageMetadata +from torch.distributed.checkpoint.planner import LoadPlanner, ReadItem + +from .hf_storage import HuggingFaceStorageReader + + +logger: logging.Logger = logging.getLogger(__name__) + +__all__ = ["QuantizedHuggingFaceStorageReader"] + + +class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader): + """ + Extension of HuggingFaceStorageReader that handles quantized tensors. + Checkpoint should have the full tensor in a SafeTensor file. The quantized + tensor should not be sharded across multiple files. + + This reader handles the dequantization of tensors during the read process, + converting them from quantized blocks to full dequantized tensors before + copying to the target tensor. + """ + + def __init__( + self, + path: str, + thread_count: int = 1, + target_dtype: torch.dtype = torch.float32, + block_size: int = 128, + ): + """ + Initialize the HuggingFace storage reader to load quantized checkpoints + + Args: + path: directory where the checkpoint will be read from. + thread_count: Number of threads to use to read distributed checkpoint. Defaults to 1. + target_dtype: Target dtype for dequantized tensor. Defaults to torch.float32. + block_size: Fixed block size for dequantization. Defaults to 128. + """ + super().__init__(path=path, thread_count=thread_count) + + self.target_dtype: torch.dtype = target_dtype + self.block_size: int = block_size + self._weight_scale_mapping: dict[str, str] = {} + # Track which file contains each tensor + self._weight_map: dict[str, str] = {} + # Cache for full tensor shapes (fqn -> shape) + self._tensor_full_shapes: dict[str, torch.Size] = {} + + def read_metadata(self) -> Any: + metadata = super().read_metadata() + + # Load quantization metadata first. + self._load_quantization_metadata() + + # Build a cache of FQN -> full tensor shape, correcting for quantized tensors. + for fqn, tensor_metadata in metadata.state_dict_metadata.items(): + # Only process TensorStorageMetadata which has size attribute. + if isinstance(tensor_metadata, TensorStorageMetadata): + # Check if this is a MXFP4 quantized tensor that needs shape correction. + if fqn.endswith("_blocks"): + # Save the quantized tensor shapes for lookup when dequantization. + self._tensor_full_shapes[fqn + "_quantized"] = tensor_metadata.size + *prefix_shape, G, B = tensor_metadata.size + dequantized_size = torch.Size([*prefix_shape, G * B * 2]) + + # Update the metadata with the size after dequantization. + # Metadata used by planner to slice state dict. + tensor_metadata.size = dequantized_size + self._tensor_full_shapes[fqn] = dequantized_size + else: + self._tensor_full_shapes[fqn] = tensor_metadata.size + + return metadata + + def _load_quantization_metadata(self): + """Load quantization metadata from the checkpoint.""" + checkpoint_path = Path(self.path) + # Load weight mapping from index file + index_file = checkpoint_path / _metadata_fn + + with open(index_file) as f: + index_data = json.load(f) + weight_map = index_data.get("weight_map", {}) + self._build_weight_scale_mapping(weight_map) + + def _build_weight_scale_mapping(self, weight_map: dict[str, str]): + """Analyze and build weight-scale tensor pairs from weight mapping.""" + # Store the complete weight map for file location lookups. + self._weight_map = weight_map + + for tensor_name in weight_map: + if tensor_name.endswith(".weight_scale_inv"): + weight_name = tensor_name.replace(".weight_scale_inv", ".weight") + if weight_name in weight_map: + self._weight_scale_mapping[weight_name] = tensor_name + # Handle MXFP4 format: _blocks and _scales. + elif tensor_name.endswith("_scales"): + blocks_name = tensor_name.replace("_scales", "_blocks") + if blocks_name in weight_map: + self._weight_scale_mapping[blocks_name] = tensor_name + + def _process_read_request( + self, f: Any, req: ReadItem, planner: LoadPlanner + ) -> None: + """Override the Helper function that processes a single read request.""" + tensor_fqn = req.storage_index.fqn + + # Check if this is a quantized tensor that needs dequantization + if self._is_tensor_quantized(tensor_fqn): + tensor = self._read_quantized_tensor_with_block_alignment(req, f) + else: + # Standard tensor reading + slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + tensor = f.get_slice(tensor_fqn)[slices] + + target_tensor = planner.resolve_tensor(req).detach() + + if target_tensor.size() != tensor.size(): + raise AssertionError( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def _get_slice_to_block_mapping( + self, req: ReadItem + ) -> tuple[tuple[int, int], tuple[int, int], slice, slice]: + """ + Calculate which blocks correspond to the requested slice. + + Args: + req: Read request containing tensor info and required slices + + Returns: + Tuple of (row_block_range, col_block_range, row_slice, col_slice) + """ + # Get the slice information + row_slice = slice( + req.storage_offsets[0], req.storage_offsets[0] + req.lengths[0] + ) + col_slice = slice( + req.storage_offsets[1], req.storage_offsets[1] + req.lengths[1] + ) + + # Calculate which blocks this slice spans + row_start_block = row_slice.start // self.block_size + row_end_block = (row_slice.stop - 1) // self.block_size + 1 # Inclusive end + + col_start_block = col_slice.start // self.block_size + col_end_block = (col_slice.stop - 1) // self.block_size + 1 # Inclusive end + + return ( + (row_start_block, row_end_block), + (col_start_block, col_end_block), + row_slice, + col_slice, + ) + + def _dequantize_tensor_mxfp4( + self, + blocks: torch.Tensor, + scales: torch.Tensor, + req: ReadItem, + group_start: int, + offset_in_first_group: int, + ) -> torch.Tensor: + """ + Dequantize a 4D tensor using MXFP4 format. + Adapted from openai's implementation: + https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68 + + Args: + blocks: Sliced quantized weight tensor of shape [a_slice, b_slice, groups_slice, B] in uint8 + scales: FULL scale tensor of shape [a, b, c] in uint8 (will be converted to exponents) + req: Read request containing slice information + group_start: The starting group index in the checkpoint + offset_in_first_group: Offset in values within the first group + + Returns: + Dequantized tensor matching the requested shape + """ + # FP4 lookup table + FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + + # blocks: [a_slice, b_slice, groups_slice, B] uint8. + # Read slightly more groups than needed, and slice at the end. + + # Slice the scales to match the blocks dimensions. + # [a_full, b_full, c_full] -> [a_slice, b_slice, groups_slice] + dim0_start = req.storage_offsets[0] + dim0_end = dim0_start + req.lengths[0] + dim1_start = req.storage_offsets[1] + dim1_end = dim1_start + req.lengths[1] + num_groups = blocks.shape[2] + scales = scales[ + dim0_start:dim0_end, + dim1_start:dim1_end, + group_start : group_start + num_groups, + ] + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, ( + f"{blocks.shape=} does not match {scales.shape=}" + ) + + lut = torch.tensor(FP4_VALUES, dtype=self.target_dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty( + rows_total, B * 2, dtype=self.target_dtype, device=blocks.device + ) + + rows_per_chunk = 16384 * 512 + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + + del idx_lo, idx_hi, blk, exp + + result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + + # Slice the last dimension to match the requested range. + if offset_in_first_group > 0 or result.shape[-1] > req.lengths[2]: + end_offset = offset_in_first_group + req.lengths[2] + result = result[..., offset_in_first_group:end_offset] + + return result + + def _dequantize_tensor( + self, + weight: torch.Tensor, + scale_inv: torch.Tensor, + full_tensor_shape: torch.Size, + slice_info: tuple[tuple[int, int], tuple[int, int], slice, slice], + ) -> torch.Tensor: + """ + Dequantize a sliced tensor using the appropriate portion of the scale tensor. + + Args: + weight: Sliced quantized weight tensor + scale_inv: Full scale inverse tensor for dequantization + full_tensor_shape: Shape of the original full tensor + slice_info: Block mapping information from _get_slice_to_block_mapping + + Returns: + Dequantized tensor + """ + (row_block_range, col_block_range, row_slice, col_slice) = slice_info + + # Convert to float32 for computation + # Certain quantized dtypes like Float8_e4m3fn + # don't support multiplication on CPU yet in PyTorch. + upcasted_weight = weight.to(torch.float32) + + # Create output tensor in target dtype + dequantized = weight.detach().to(dtype=self.target_dtype, copy=True) + + # Get the actual slice boundaries + row_start_global = row_slice.start + row_end_global = row_slice.stop + col_start_global = col_slice.start + col_end_global = col_slice.stop + + # Apply scaling factors to each block that intersects with our slice + for block_i in range(row_block_range[0], row_block_range[1]): + for block_j in range(col_block_range[0], col_block_range[1]): + # Calculate the block boundaries in global coordinates + block_row_start_global = block_i * self.block_size + block_row_end_global = min( + block_row_start_global + self.block_size, full_tensor_shape[0] + ) + block_col_start_global = block_j * self.block_size + block_col_end_global = min( + block_col_start_global + self.block_size, full_tensor_shape[1] + ) + + # Find the intersection of the block with our slice + intersect_row_start = max(block_row_start_global, row_start_global) + intersect_row_end = min(block_row_end_global, row_end_global) + intersect_col_start = max(block_col_start_global, col_start_global) + intersect_col_end = min(block_col_end_global, col_end_global) + + # Skip if no intersection + if ( + intersect_row_start >= intersect_row_end + or intersect_col_start >= intersect_col_end + ): + continue + + # Convert global coordinates to local coordinates in the sliced tensor + local_row_start = intersect_row_start - row_start_global + local_row_end = intersect_row_end - row_start_global + local_col_start = intersect_col_start - col_start_global + local_col_end = intersect_col_end - col_start_global + + # Get the block from the sliced tensor + block = upcasted_weight[ + local_row_start:local_row_end, local_col_start:local_col_end + ] + + # Apply the scale factor + scale = scale_inv[block_i, block_j] + block = block * scale + + # Convert block to target dtype and store + block_converted = block.to(dtype=self.target_dtype) + dequantized[ + local_row_start:local_row_end, local_col_start:local_col_end + ] = block_converted + + return dequantized + + def _is_tensor_quantized(self, tensor_fqn: str) -> bool: + """ + Check if a tensor is a quantized. + + Args: + tensor_fqn: Fully qualified name of the tensor + + Returns: + True if tensor is quantized and has a corresponding scale tensor, + False otherwise + """ + # Skip scale tensors themselves + if tensor_fqn.endswith((".weight_scale_inv", "_scales")): + return False + + # Check if this weight tensor has a corresponding scale tensor + if tensor_fqn not in self._weight_scale_mapping: + return False + + return True + + def _read_quantized_tensor_with_block_alignment( + self, req: ReadItem, safetensor_file: Any + ) -> torch.Tensor: + """ + Read a quantized tensor with block alignment. + + Args: + req: Read request containing tensor info and required slices + safetensor_file: Open safetensors file handle + + Returns: + Dequantized tensor ready for use + """ + tensor_fqn = req.storage_index.fqn + scale_fqn = self._weight_scale_mapping[tensor_fqn] + + try: + group_start = 0 + offset_in_first_group = 0 + if tensor_fqn.endswith("_blocks"): + # Full tensor is a 4D MXFP4 quantized tensor: [..., G, B]. + # Each group G produces B * 2 dequantized values. + # Checkpoint [..., G, B] -> dequantized [..., G*B*2]. + + # The planner gives 3D requests based on the dequantized shape. + # Need to figure out which groups (dimension 2 in checkpoint) to read. + + # Use the quantized checkpoint shape to get the correct B. + *prefix_shape, B = self._tensor_full_shapes[tensor_fqn + "_quantized"] + values_per_group = B * 2 # Each byte has 2 nibbles (4-bit values). + + # Calculate which groups we need based on the requested range in dim 2. + # Ensure the reequest is in 3D. + assert len(req.storage_offsets) == 3 + + # Positions in dequantized space. + dim2_start_deq = req.storage_offsets[2] + dim2_length_deq = req.lengths[2] + dim2_end_deq = dim2_start_deq + dim2_length_deq + + # Convert to group indices. + group_start = dim2_start_deq // values_per_group + group_end = (dim2_end_deq + values_per_group - 1) // values_per_group + + # Read only the necessary groups from checkpoint. + weight_slices_4d = ( + slice( + req.storage_offsets[0], req.storage_offsets[0] + req.lengths[0] + ), + slice( + req.storage_offsets[1], req.storage_offsets[1] + req.lengths[1] + ), + slice(group_start, group_end), + slice(None), # Read all B values for each group. + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[ + weight_slices_4d + ] + + # Also track the offset within the first group + offset_in_first_group = dim2_start_deq - ( + group_start * values_per_group + ) + else: + # 2D quantized tensor, use 2d block partition. + weight_slices = tuple( + slice(offset, offset + length) + for offset, length in zip(req.storage_offsets, req.lengths) + ) + quantized_tensor = safetensor_file.get_slice(tensor_fqn)[weight_slices] + + # Load the corresponding scale inverse tensor (full tensor) + scale_file_name = self._weight_map.get(scale_fqn) + if scale_file_name is None: + raise ValueError(f"Scale tensor {scale_fqn} not found in weight_map") + + # Check if scale tensor is in the same file as the weight tensor + weight_file_name = self._weight_map.get(tensor_fqn) + + if scale_file_name == weight_file_name: + # Scale tensor is in the same file, use current handle + scale_inv = safetensor_file.get_tensor(scale_fqn) + else: + # Scale tensor is in a different file, need to open it + from safetensors import safe_open # type: ignore[import] + + scale_file_path = Path(self.path) / scale_file_name + with safe_open( + scale_file_path, framework="pt", device="cpu" + ) as scale_file: + scale_inv = scale_file.get_tensor(scale_fqn) + + # Get the full tensor shape from our O(1) lookup cache + full_tensor_shape = self._tensor_full_shapes.get(tensor_fqn) + if full_tensor_shape is None: + raise ValueError(f"Could not find full tensor shape for {tensor_fqn}") + + # Determine which dequantization function to use. + if len(full_tensor_shape) == 2: + # 2D block-wise quantization, e.g., used in deepseek v3.1 + slice_info = self._get_slice_to_block_mapping(req) + dequantized_tensor = self._dequantize_tensor( + weight=quantized_tensor, + scale_inv=scale_inv, + full_tensor_shape=full_tensor_shape, + slice_info=slice_info, + ) + elif tensor_fqn.endswith("_blocks"): + # 4D with blocks along dimension 2, used in MXFP4, e.g. gpt-oss + dequantized_tensor = self._dequantize_tensor_mxfp4( + blocks=quantized_tensor, + scales=scale_inv, + req=req, + group_start=group_start, + offset_in_first_group=offset_in_first_group, + ) + else: + raise ValueError("Unsupported quantization types") + + return dequantized_tensor + + except Exception as e: + logger.error("Failed to read the quantized tensor!!") + raise e diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f24b891aa895d3a445908fe6d084e13f9b05da --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/resharding.py @@ -0,0 +1,69 @@ +from torch.distributed.checkpoint.metadata import ChunkStorageMetadata + + +__all__: list[str] = [] + + +def _check_shard_metadata_pair_overlap( + shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata +) -> bool: + """Check if two shards overlap.""" + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.offsets) + for i in range(ndims): + if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]: + return False + if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]: + return False + + return True + + +def _shards_get_overlap_region_wrt_saved_tensor( + saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata +) -> list[tuple[int, int, int, int]]: + """ + Return the overlapping region between saved_shard and current_shard. + + There returned list has the same number of elements as the tensor's dimension. + For each element, we produce a tuple with the following contents: + (dimension, `saved_shard` offset, `current_shard` offset, length) + + Offsets are relative to each shard. + """ + narrows = [] + for dim, ( + saved_shard_offset, + current_shard_offset, + saved_shard_size, + current_shard_size, + ) in enumerate( + zip( + saved_shard.offsets, + current_shard.offsets, + saved_shard.sizes, + current_shard.sizes, + ) + ): + min_range_end = min( + saved_shard_offset + saved_shard_size, + current_shard_offset + current_shard_size, + ) + + length = min_range_end - max(current_shard_offset, saved_shard_offset) + + if saved_shard_offset > current_shard_offset: + offset_for_saved_tensor = 0 + offset_for_current_tensor = saved_shard_offset - current_shard_offset + else: + offset_for_saved_tensor = current_shard_offset - saved_shard_offset + offset_for_current_tensor = 0 + + narrows.append( + (dim, offset_for_saved_tensor, offset_for_current_tensor, length) + ) + + return narrows diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbacc66aaaffe038bced44b9ca7b466a2246f90 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/staging.py @@ -0,0 +1,474 @@ +import os +import tempfile +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, cast, Optional, Union +from typing_extensions import deprecated, Protocol, runtime_checkable + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint._pg_transport import PGTransport +from torch.distributed.checkpoint._state_dict_stager import StateDictStager +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +__all__ = ["AsyncStager", "BlockingAsyncStager", "DefaultStager", "StagingOptions"] + +""" +Experimental staging module for PyTorch Distributed Checkpointing. +This module provides advanced staging capabilities for checkpoints including: +- Asynchronous staging using ThreadPoolExecutor +- Pinned memory allocation for faster CPU-GPU transfers +- Shared memory support for multi-process scenarios +- Non-blocking CUDA operations with stream synchronization +- Caching of frequently used storages for efficient memory management +- Automatic resource cleanup and memory management +Classes: + AsyncStager: Protocol defining the staging interface + StagingOptions: Configuration dataclass for staging behavior + DefaultStager: Default implementation with comprehensive staging features + BlockingAsyncStager: Implementation of AsyncStager which stages the state_dict + on CPU RAM and blocks until the copy is complete. Please use DefaultStager instead. +""" + + +@runtime_checkable +class AsyncStager(Protocol): + """ + This protocol is meant to provide customization and extensibility for dcp.async_save, allowing users + to customize how data is staged previous to executing the usual dcp.save path in parallel. + The expected order of operations (concretely defined in `torch.distributed.state_dict_saver.async_save`) + is the following: + + 1. AsyncStager.stage_data(state_dict): + This call gives the AsyncStager the opportunity to 'stage' + the state_dict. The expectation and purpose of staging in this context is to create a "training-safe" + representation of the state dict, meaning that any updates to module data after staging is complete + should not be reflected in the state dict returned from this method. For example, in the default + case a copy of the entire state dict is created on CPU RAM and returned here, allowing users + to continue training without risking changes to data which is being serialized. + + 2. dcp.save is called on the state_dict returned from stage in parallel. This call is responsible + for serializing the state_dict and writing it to storage. + + 3. If AsyncStager.should_synchronize_after_execute is True, this method will be called immediately after + the serialization thread starts and before returning from dcp.async_save. If this is set to False, + the assumption is the user has defined a custom synchronization point for the purpose of further + optimizing save latency in the training loop (for example, by overlapping staging with the + forward/backward pass), and it is the respondsibility of the user to call `AsyncStager.synchronize_staging` + at the appropriate time. + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = True + + @property + def should_synchronize_after_execute(self) -> bool: + """ + Whether to synchronize after executing the stage. + """ + return self._synchronize_after_execute + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is + inoculated from any updates incurred after the stage call is complete. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement stage method" + ) + + @deprecated( + "`synchronize_staging` is deprecated and will be removed in future versions." + "Please use staging_future from AsyncSaveResponse instead.", + category=FutureWarning, + ) + def synchronize_staging(self) -> None: + """ + In the case `stage` is async in some way, this method should be called to ensure staging + is complete and it is safe to begin modifying the original `state_dict` + """ + + def close(self) -> None: + """ + Clean up all resources used by the stager. + """ + + +@dataclass +class StagingOptions: + """ + Configuration options for checkpoint staging behavior. + + Attributes: + use_pinned_memory (bool): Enable pinned memory allocation for faster + CPU-GPU transfers. Requires CUDA to be available. Default: True + use_shared_memory (bool): Enable shared memory for multi-process + scenarios. Useful when multiple processes need access to the + same staged data. Default: True + use_async_staging (bool): Enable asynchronous staging using a + background thread pool. Allows overlapping computation with + staging operations. Requires CUDA. Default: True + use_non_blocking_copy (bool): Use non-blocking device memory + copies with stream synchronization. Improves performance by + allowing CPU work to continue during GPU transfers. Default: True + + Note: + CUDA-dependent features will raise exception if CUDA is not available. + """ + + use_pinned_memory: bool = True + use_shared_memory: bool = True + use_async_staging: bool = True + use_non_blocking_copy: bool = True + + +class DefaultStager(AsyncStager): + """ + DefaultStager provides a full-featured staging implementation that combines + multiple optimization techniques for efficient checkpoint preparation. + + The staging process works as follows: + 1. State dictionary is submitted for staging (sync or async) + 2. Tensors are copied from GPU to optimized CPU storage + 3. CUDA operations are synchronized if non-blocking copies are used + 4. Staged state dictionary is returned or made available via Future + + Usage Patterns: + # Synchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=False)) + staged_dict = stager.stage(state_dict) + stager.close() + + # Asynchronous staging + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + # ... do other work ... + staged_dict = future.result() + stager.close() + + # Context manager pattern (recommended) + stager = DefaultStager(config) + with stager: + result = stager.stage(state_dict) + + Performance Considerations: + - Async staging provides best performance when model computation + can overlap with staging operations + - Pinned memory improves CPU-GPU transfer speeds but uses more memory + - Shared memory allows efficient IPC to checkpoint process + - Non-blocking copies reduce GPU idle time during memory transfers + + Thread Safety: + DefaultStager is not thread-safe. Each thread should use its own + instance, or external synchronization should be provided. + """ + + def __init__( + self, + config: StagingOptions = StagingOptions(), + ): + self._config = config + self._state_dict_stager = StateDictStager( + pin_memory=config.use_pinned_memory, share_memory=config.use_shared_memory + ) + self._staging_executor = None + self._staging_stream = None + if self._config.use_async_staging: + # pyrefly: ignore [bad-assignment] + self._staging_executor = ThreadPoolExecutor(max_workers=1) + if torch.accelerator.is_available(): + # Note: stream needs to be initialized on the main thread after default cuda + # stream is setup/used to avoid the risk of accidentally reusing the main + # compute stream or in other cases kernels actually launching from the + # main thread. + # pyrefly: ignore [bad-assignment] + self._staging_stream = torch.Stream() + + if self._config.use_non_blocking_copy: + if not torch.accelerator.is_available(): + raise AssertionError( + "Non-blocking copy requires that the current accelerator is available." + ) + + self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None + + def stage( + self, + state_dict: STATE_DICT_TYPE, + **kwargs: Any, + ) -> Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]]: + """ + This function is responsible for staging staging the state_dict. + See class docstring for more details on staging. + If use_async_staging is True, it will return a Future object that will be + fulfilled when staging is complete. + If use_async_staging is False, it will return the fully staged state_dict. + + Args: + state_dict (STATE_DICT_TYPE): The state_dict to be staged. + """ + if self._config.use_async_staging: + if self._staging_executor is None: + raise AssertionError( + "staging_executor should not be None for async staging" + ) + self._staging_future = self._staging_executor.submit( + self._stage, + state_dict, + **kwargs, + ) + return self._staging_future + else: + return self._stage(state_dict, **kwargs) + + def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: + if self._config.use_non_blocking_copy: + if not (self._staging_stream or not self._config.use_async_staging): + raise AssertionError( + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." + ) + with ( + self._staging_stream + if self._staging_stream is not None + else nullcontext() + ): + state_dict = self._state_dict_stager.stage( + state_dict, non_blocking=self._config.use_non_blocking_copy + ) + # waits for the enqued copy operations to finish. + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() + else: + state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False) + return state_dict + + def close(self) -> None: + """ + Clean up all resources used by the DefaultStager. Shuts down the ThreadPoolExecutor + used for async staging operations and cleans up the underlying StateDictStager's + cached storages. Should be called when the stager is no longer needed to prevent + resource leaks, especially in long-running applications. After calling close(), + the stager should not be used for further staging operations. + + Example Usage: + stager = DefaultStager(StagingOptions(use_async_staging=True)) + future = stager.stage(state_dict) + result = future.result() + stager.close() # Clean up all resources + """ + if self._staging_executor: + self._staging_executor.shutdown(wait=True) + + def synchronize_staging(self) -> None: + """ + When use_async_staging is True, this method will wait until staging is complete. + If use_async_staging is False, this method is a no-op. + """ + if self._staging_future is not None: + self._staging_future.result() + + +class BlockingAsyncStager(AsyncStager): + """ + An implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. + This implementation also provides an option to optimize stage latency using pinned memory. + + N.B. synchronize_staging is a no-op in this case. + + + """ + + # default to True since the common case is to stage synchronously + _synchronize_after_execute: bool = False + + def __init__( + self, + cache_staged_state_dict: bool = False, + type_check: bool = False, + ): + """ + Initializes the BlockingAsyncStager. + + Args: + cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency + at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation + that the stager is maintained and reused for multiple dcp.async_save calls. Default to False. + type_check: Whether to perform a type check during cpu_offload. Defaults to False. + + """ + self.cache_staged_state_dict = cache_staged_state_dict + self.type_check = type_check + self.state_dict_cache: Optional[STATE_DICT_TYPE] = None + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """ + Returns a copy of `state_dict` on the CPU. + """ + + if not self.cache_staged_state_dict: + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=self.type_check) + return staged_state_dict + + if self.state_dict_cache is None: + self.state_dict_cache = _create_cpu_state_dict(state_dict, pin_memory=True) + return _copy_state_dict(state_dict, self.state_dict_cache) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + pass + + +class _ReplicationStager(AsyncStager): + """ + An AsyncStager implementation that replicates state_dict across training ranks + using PGTransport. + + Args: + pg: ProcessGroup for distributed communication + timeout: Timeout for communication operations + device: Device to use for tensor operations + storage_dir: Directory to store persisted state_dicts + + Warning: This is experimental and subject to change. + """ + + _synchronize_after_execute: bool = False + + def __init__( + self, + pg: ProcessGroup, + timeout: timedelta = timedelta(minutes=30), + device: torch.device = torch.device("cpu"), + storage_dir: Optional[str] = None, + ): + self._pg = pg + self._timeout = timeout + # pyrefly: ignore [read-only] + self._device = device + self._transport = PGTransport(pg, timeout, device, None) + + # Set up storage directory for persisting exchanged state_dicts + if storage_dir is None: + self._storage_dir = tempfile.mkdtemp(prefix="replication_stager_") + else: + self._storage_dir = storage_dir + os.makedirs(self._storage_dir, exist_ok=True) + + def stage( + self, state_dict: STATE_DICT_TYPE + ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + """ + Stage the state_dict by replicating it across ranks. Returns a state_dict representing + the received replica. + + Perform the actual replication logic. Creates bidirectional pairs where each rank exchanges + state_dict with its partner at (rank + world_size//2) % world_size. + Uses simple rank-based ordering to prevent deadlocks. + + Assumes world_size is always even. + """ + if not dist.is_initialized(): + return state_dict + + world_size = dist.get_world_size() + + current_rank = dist.get_rank() + + # Calculate partner rank using half-world offset + # creates bidirectional pairs for replication. + offset = world_size // 2 + partner_rank = (current_rank + offset) % world_size + + # Use simple rank-based ordering to prevent deadlocks. + # Lower-numbered rank sends first, higher-numbered rank receives first. + if current_rank < partner_rank: + # Send first, then receive + self._transport.send_checkpoint([partner_rank], state_dict) + received_state_dict = self._transport.recv_checkpoint(partner_rank) + else: + # Receive first, then send + received_state_dict = self._transport.recv_checkpoint(partner_rank) + self._transport.send_checkpoint([partner_rank], state_dict) + + # Persist the received state_dict for future discoverability + received_state_dict = cast(STATE_DICT_TYPE, received_state_dict) + self._persist_state_dict(received_state_dict, current_rank, partner_rank) + + return received_state_dict + + def _persist_state_dict( + self, state_dict: STATE_DICT_TYPE, current_rank: int, partner_rank: int + ) -> None: + """ + Persist the received state_dict to disk for future discoverability. + Only keeps one replica per rank, overwriting any previous replica. + Uses atomic write pattern (temp file + rename). + + Args: + state_dict: The state_dict received from partner rank + current_rank: Current rank that received the state_dict + partner_rank: Rank that sent the state_dict + """ + final_path = self._get_persisted_path(current_rank, partner_rank) + temp_path = final_path + ".tmp" + + try: + # Ensure parent directory exists and is writable + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Write to temporary file with explicit flushing + with open(temp_path, "wb") as f: + torch.save(state_dict, f) + # Flush application buffers to OS buffers + f.flush() + # Force OS buffers to disk for durability + os.fsync(f.fileno()) + + # Atomic rename to final location + os.rename(temp_path, final_path) + except Exception as e: + # Clean up temp file if it exists + try: + if os.path.exists(temp_path): + os.remove(temp_path) + except Exception: + pass # Ignore cleanup errors + # Re-raise the original exception with more context + raise RuntimeError( + f"Failed to persist state_dict from rank {partner_rank} to rank {current_rank}: {e}" + ) from e + + def _get_persisted_path(self, current_rank: int, partner_rank: int) -> str: + """ + Get the file path where a state_dict would be persisted. + + Args: + current_rank: Current rank + + Returns: + File path for the persisted state_dict + """ + filename = f"rank_{current_rank}_replica_partner_{partner_rank}.pt" + return os.path.join(self._storage_dir, filename) + + def synchronize_staging(self) -> None: + """ + No-op function, since staging is blocking. + """ + + def close(self) -> None: + """ + Clean up resources. Persisted files are intentionally left for future discovery. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..6a31144348acb669e0a2f8e17805d5650d5d61d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py @@ -0,0 +1,1634 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import gc +import warnings +from collections.abc import Callable, Generator, Iterable +from dataclasses import asdict, dataclass, field +from itertools import chain +from typing import Any, cast, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._state_dict_utils import ( + _broadcast_state_dict, + _distribute_state_dict, + _flatten_state_dict, + _gather_state_dict, + _offload_state_dict_to_cpu, + _unflatten_state_dict, +) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state_if_fully_sharded_module, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.tensor import DTensor +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils._pytree import tree_map_only + + +__all__ = [ + "FQNS_T", + "PrimitiveType", + "ValueType", + "DictValueType", + "ListDictValueType", + "OptimizerStateType", + "StateDictOptions", + "get_model_state_dict", + "get_optimizer_state_dict", + "get_state_dict", + "set_model_state_dict", + "set_optimizer_state_dict", + "set_state_dict", +] + + +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" + +FQNS_T = set[str] +PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] +ValueType = Union[ + PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"] +] +DictValueType = dict[str, ValueType] +ListDictValueType = list[DictValueType] +OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]] + + +_patched_state_dict: set[Callable] = set() + + +@contextlib.contextmanager +def _gc_context(): + is_enabled = gc.isenabled() + gc.disable() + try: + yield + finally: + if is_enabled: + gc.enable() + + +@dataclass +class StateDictOptions: + """ + This dataclass specifies how get_state_dict/set_state_dict will work. + + - ``full_state_dict``: if this is set to True, all the tensors in the + returned state_dict will be gathered. No ShardedTensor and DTensor + will be in the returned state_dict. + + - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if + ``full_state_dict`` is also true, then only the rank0 will get the + state_dict and all other ranks will get empty state_dict. + + - ``ignore_frozen_params``: if the value is True, the returned state_dict + won't contain any frozen parameters -- the ``requires_grad`` is False. + The default value is False. + + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option + indicates whether to keep the submodule prefixes from the state_dict keys. + or example, if the submodule is ``module.pretrain`` and the full FQN of + the parameter is ``pretrain.layer1.weight`` of the param. When this option + is True, the parameter's key in the returned state_dict will be + ``pretrain.layer1.weight``. If the options is False, the key will be + ``layer1.weight``. + Note that if ``keep_submodule_prefixes`` is False, there may be conflicted + FQNs, hence there should be only one submodule in ``submodules``. + + - ``strict``: the ``strict`` option when ``set_state_dict`` calls + model.load_state_dict(). + + - ``broadcast_from_rank0``: when the option is True, rank0 should receive a + full state_dict and will broadcast the tensors in the state_dict/ + optim_state_dict one by one to other ranks. Other ranks will receive + the tensors and shard according to the local shards in the model and + optimizer. ``full_state_dict`` must be set to True when using this option. + This option currently only supports DTensor, not the legacy ShardedTensor. + """ + + full_state_dict: bool = False + cpu_offload: bool = False + ignore_frozen_params: bool = False + keep_submodule_prefixes: bool = True + strict: bool = True + broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False + dsd_fqn_modifiers: str = "_fqn_modifiers" + + +@dataclass +class _StateDictInfo(StateDictOptions): + fqn_param_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + shared_params_mapping: dict[ + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], + ] = field(default_factory=dict) + submodule_prefixes: set[str] = field(default_factory=set) + handle_model: bool = True + handle_optim: bool = True + fsdp_context: Callable = contextlib.nullcontext + fsdp_modules: list[nn.Module] = field(default_factory=list) + + +def _get_fqns( + model: nn.Module, + name: str, + dsd_fqn_modifiers: str = "_fqn_modifiers", + skip_ddp_prefix: bool = True, + skip_compiler_prefix: bool = True, +) -> FQNS_T: + """ + This API is used to convert the name of a parameter to the FQNs. For FSDP + without `use_orig_params`, the name of FlatParameter can be mapped to + multiple original parameters. As a result, the return type of this function + is `set[str]`. + + Args: + module (nn.Module): the root model. + name (str): the name + skip_ddp_prefix (bool): whether to skip DDP's `module` prefix + + Returns: + The canonical FQNs based on the model traversal. + """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") + if "." not in name: + return {name} + + obj_names = name.split(".") + fqn_obj_names = [] + curr_obj = model + for i, curr_obj_name in enumerate(obj_names): + if isinstance(curr_obj, DDP): + if curr_obj_name != "module": + raise AssertionError(f"Expected 'module', got '{curr_obj_name}'") + curr_obj = curr_obj.module + if not skip_ddp_prefix: + fqn_obj_names.append(curr_obj_name) + elif isinstance(curr_obj, FSDP): + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: + prefix = ".".join(fqn_obj_names) + flat_param = getattr(curr_obj, _FLAT_PARAM) + if prefix: + prefix = f"{prefix}." + return {f"{prefix}{fqn}" for fqn in flat_param._fqns} + curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) + if curr_obj_name != FSDP_WRAPPED_MODULE: + # pyrefly: ignore [bad-argument-type] + fqn_obj_names.append(curr_obj_name) + curr_obj = getattr(curr_obj, curr_obj_name) + elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): + if curr_obj_name != "_orig_mod": + raise AssertionError(f"Expected '_orig_mod', got '{curr_obj_name}'") + curr_obj = curr_obj._orig_mod + if not skip_compiler_prefix: + fqn_obj_names.append(curr_obj_name) + else: + # In some modules, _fqn_modifiers would not shown in the state_dict keys, + # skip them in the fqn to ensure load stat dict successfully for them. + if hasattr(curr_obj, dsd_fqn_modifiers): + if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get( + curr_obj_name + ): + if hasattr(curr_obj, removed_fqn): + curr_obj = getattr(curr_obj, removed_fqn) + # pyrefly: ignore [bad-argument-type] + fqn_obj_names.append(curr_obj_name) + if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: + if i != len(obj_names) - 1: + raise RuntimeError("Expect `_extra_state` to be the last obj name") + else: + curr_obj = getattr(curr_obj, curr_obj_name) + + return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} + + +class _EXTRA_STATE: + pass + + +def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"): + visited_modules: set[nn.Module] = set() + + def recurse(module: nn.Module, curr_fqn: str) -> Generator: + visited_modules.add(module) + + curr_fqn = f"{curr_fqn}." if curr_fqn else "" + for name, submodule in module.named_children(): + if submodule in visited_modules: + continue + # if user have state_dict_hooks in their model, they can add the state_dict key changes + # at dsd_fqn_modifiers in input to align with the function of state_dict_hook + if ( + hasattr(module, dsd_fqn_modifiers) + and name in getattr(module, dsd_fqn_modifiers)().values() + ): + # skip _fqn_modifiers here thus remove the last `.` added + new_fqn = curr_fqn[:-1] + else: + new_fqn = f"{curr_fqn}{name}" + yield from recurse(submodule, new_fqn) + + for name, obj in chain( + module.named_buffers(recurse=False), module.named_parameters(recurse=False) + ): + if name in module._non_persistent_buffers_set: + continue + new_fqn = f"{curr_fqn}{name}" + yield new_fqn, obj + + if ( + getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) + != nn.Module.get_extra_state + ): + new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" + yield new_fqn, _EXTRA_STATE() + + yield from recurse(model, "") + + +def _verify_options( + model: nn.Module, + optims: tuple[torch.optim.Optimizer, ...], + optim_only: bool, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> _StateDictInfo: + """ + Verify the model and options passed by the user and generates _StateDictInfo. + """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + stacklevel=2, + ) + if optim_only and not optims: + raise RuntimeError( + "Optimizers are not passed in but optim_only is set to True." + ) + + options = options or StateDictOptions() + + fqn_param_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + shared_params_mapping: dict[ + Union[str, torch.Tensor], Union[set[str], torch.Tensor] + ] = {} + for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + + fqns = _get_fqns(model, name) + fqn = fqn_param_mapping.get(param) + if fqn is not None: + cast(set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() + for fqn in fqns: + if not isinstance(param, _EXTRA_STATE): + fqn_param_mapping[fqn] = param + + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + + submodule_prefixes: set[str] = set() + if submodules: + submodules = set(submodules) + for name, module in model.named_modules(): + if module not in submodules: + continue + fqns = _get_fqns(model, name) + if len(fqns) != 1: + raise AssertionError("Submodule FQN should only have 1 instance") + submodule_prefixes.update(f"{fqn}." for fqn in fqns) + + if options.broadcast_from_rank0 and not options.full_state_dict: + raise ValueError( + "full_state_dict must be True when broadcast_from_rank0 is True." + ) + fsdp_modules = FSDP.fsdp_modules(model) + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig + fsdp_context: Callable + if fsdp_modules: + # FSDP API only work if at least one FSDP instance exists. + if options.full_state_dict: + state_dict_config = FullStateDictConfig( + offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload + ) + optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + rank0_only=(options.cpu_offload or options.broadcast_from_rank0), + ) + state_dict_type = StateDictType.FULL_STATE_DICT + else: + state_dict_config = ShardedStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=options.cpu_offload, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT + + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="FSDP.state_dict_type", category=FutureWarning + ) + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + + fsdp_context = functools.partial( + fsdp_state_dict_type_without_warning, + module=model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ) + else: + fsdp_context = contextlib.nullcontext + + return _StateDictInfo( + **asdict(options), + fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, + submodule_prefixes=submodule_prefixes, + fsdp_context=fsdp_context, + fsdp_modules=cast(list[nn.Module], fsdp_modules), + handle_model=not optim_only, + handle_optim=(len(optims) > 0), + ) + + +def _verify_state_dict( + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + for module in info.fsdp_modules: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state is None: + raise AssertionError("Expected a fsdp_state with a fsdp module.") + + # Verify if the model_state_dict and optim_state_dict are valid. This API + # should give the users an explicit error message to debug or report. + if ( + info.handle_model + and not model_state_dict + and not info.submodule_prefixes + and not info.ignore_frozen_params + and not (info.cpu_offload and info.full_state_dict) + and info.strict + and not info.broadcast_from_rank0 + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save " + "or load, but model state_dict is empty." + f"rank = {dist.get_rank()=}." + ) + + if info.handle_optim: + if ( + not optim_state_dict + and not (info.cpu_offload and info.full_state_dict) + and (not info.broadcast_from_rank0) + ): + raise RuntimeError( + "The option indicates that model state_dict is required to save, " + f"or load but optim state_dict is empty. {optim_state_dict}" + ) + + for key in model_state_dict: + if _FLAT_PARAM in key: + raise RuntimeError( + f"{key} contains {_FLAT_PARAM}. This can happen if the model " + "is not the root module." + ) + + +def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable: + call = getattr(obj, api) + if call in _patched_state_dict: + call = functools.partial(getattr(obj.__class__, api), self=obj) + return call + + +def _maybe_full_or_cpu_state_dict( + state_dict: dict[str, Any], info: _StateDictInfo +) -> dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + () + if (not info.cpu_offload or not torch.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + +@torch.no_grad() +def _get_model_state_dict( + model: nn.Module, info: _StateDictInfo +) -> dict[str, ValueType]: + if not info.handle_model: + return {} + + with info.fsdp_context(): + state_dict = _state_dict_fn(model, "state_dict")() + + for key in list(state_dict.keys()): + fqns = _get_fqns(model, key) + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for key '{key}', got {len(fqns)}: {fqns}" + ) + fqn = next(iter(fqns)) + if fqn != key: + # As we only support FSDP, DDP, and TP, the only cases are + # wrapper-based DDP and compiler. Verify if the assumption + # is correct. + def verify(key, fqn) -> bool: + if len(fqn) >= len(key): + return False + fqn_split = fqn.split(".") + key_split = key.split(".") + fqn_idx = 0 + for key_idx, key_name in enumerate(key_split): + if key_name == fqn_split[fqn_idx]: + fqn_idx += 1 + if fqn_idx == len(fqn_split): + return key_idx == len(key_split) - 1 + elif key_name in ("module", "_orig_mod"): + continue + else: + return False + return True + + if not verify(key, fqn): + raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") + state_dict[fqn] = state_dict.pop(key) + + if info.submodule_prefixes: + new_state_dict: dict[str, ValueType] = {} + # TODO: make this faster. + for fqn in state_dict: + for prefix in info.submodule_prefixes: + if not fqn.startswith(prefix): + continue + if info.keep_submodule_prefixes: + new_state_dict[fqn] = state_dict[fqn] + else: + new_fqn = fqn[len(prefix) :] + new_state_dict[new_fqn] = state_dict[fqn] + state_dict = new_state_dict + + if info.ignore_frozen_params: + for key, param in model.named_parameters(): + if param.requires_grad: + continue + fqns = _get_fqns(model, key) + for fqn in fqns: + state_dict.pop(fqn) + + return _maybe_full_or_cpu_state_dict(state_dict, info) + + +@torch.no_grad() +def _load_model_state_dict( + model: nn.Module, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> _IncompatibleKeys: + if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): + return _IncompatibleKeys({}, {}) + + local_state_dict = {} + for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers): + fqns = _get_fqns(model, key, info.dsd_fqn_modifiers) + fqns_with_prefix = _get_fqns( + model, + key, + info.dsd_fqn_modifiers, + skip_ddp_prefix=False, + skip_compiler_prefix=False, + ) + + for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): + if ( + not info.broadcast_from_rank0 or dist.get_rank() == 0 + ) and fqn != fqn_with_prefix: + load_value = state_dict.pop(fqn, None) + if load_value is None: + if info.strict: + raise RuntimeError(f"Missing key: {fqn}.") + else: + state_dict[fqn_with_prefix] = load_value + local_state_dict[fqn_with_prefix] = value + + assign = False + if info.broadcast_from_rank0 or info.full_state_dict: + devices = set() + for value in local_state_dict.values(): + if torch.is_tensor(value) and value.dim() > 0: + devices.add(value.device) + # In lora state_dict, there could be multiple devices, with meta device inside. + # Take the other device in the broadcast/distribtue, and set assign to True + if torch.device("meta") in devices: + devices.remove(torch.device("meta")) + assign = True + if len(devices) == 0: + devices.add(dist.distributed_c10d._get_pg_default_device()) + elif len(devices) > 1: + raise ValueError("Multiple devices found") + + if info.broadcast_from_rank0: + _broadcast_state_dict( + state_dict, + local_state_dict, + device=devices.pop(), + strict=info.strict, + cpu_offload=info.cpu_offload, + ) + elif info.full_state_dict: + _distribute_state_dict(state_dict, local_state_dict, device=devices.pop()) + state_dict.update(local_state_dict) + + with info.fsdp_context(): + return cast( + _IncompatibleKeys, + _state_dict_fn(model, "load_state_dict")( + state_dict=state_dict, strict=info.strict, assign=assign + ), + ) + + +def _init_optim_state(optim: torch.optim.Optimizer) -> None: + """ + Initialize optim states by calling the step() with zero grads. + """ + if optim.state: + # The optimizer state is initialized. + return + + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.grad is not None: + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: + if param.requires_grad: + param.grad = torch.zeros_like(param) + + # Some optimizers will update parameters regardless of grads due to lr, so + # make lr to zero when calling `step()`. + lrs = [] + for param_group in optim.param_groups: + if "lr" in param_group: + lrs.append(param_group["lr"]) + param_group["lr"] = ( + torch.tensor(0.0) + if isinstance(param_group["lr"], torch.Tensor) + else 0.0 + ) + optim.step(closure=None) + # Whether to recover the "lr" should not matter too much as we will + # restore checkpointing later. + for param_group in optim.param_groups: + if "lr" in param_group: + param_group["lr"] = lrs.pop(0) + optim.zero_grad(set_to_none=True) + + +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_groups": [ + { + "lr": 0.0, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_groups.layer1.weight.lr": 0.1, + "param_groups.layer2.weight.lr": 0.1, + "param_groups.layer1.weight.betas": (0.9, 0.95), + "param_groups.layer2.weight.betas": (0.9, 0.95), + } + + The "state" section supports arbitrary levels of nesting for optimizers like Shampoo. + """ + + def _flatten_state_nested_dict( + nested_dict: dict[str, Any], prefix: str + ) -> dict[str, ValueType]: + """ + Recursively flatten a nested dictionary with dot-separated keys. + + Args: + nested_dict: The dictionary to flatten + prefix: The prefix to prepend to all keys + + Returns: + Flattened dictionary with dot-separated keys + """ + flattened: dict[str, ValueType] = {} + + for key, value in nested_dict.items(): + # Convert all keys to strings for flattening + str_key = str(key) + full_key = f"{prefix}.{str_key}" if prefix else str_key + + if isinstance(value, dict): + # Recursively flatten nested dictionaries + flattened.update(_flatten_state_nested_dict(value, full_key)) + else: + # Base case: store the value with the flattened key + _raise_if_type_not_supported(value) + flattened[full_key] = value + + return flattened + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float, dict)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float, dict states now. " + f"Type is {type(v)}." + ) + + ret: dict[str, ValueType] = {} + + # Handle the "state" section with recursive flattening + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + state_prefix = f"{_STATE}.{fqn}" + ret.update( + _flatten_state_nested_dict(cast(dict[str, Any], state), state_prefix) + ) + + # Handle the "param_groups" section with two-level flattening + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(list[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + Supports arbitrary levels of nesting in the state section through recursive reconstruction. + + See the docstring of _flatten_optim_state_dict() for more detail. + """ + + def _reconstruct_nested_dict( + flattened_key: str, flattened_dict: dict[str, ValueType] + ) -> dict[str, ValueType]: + """ + Reconstructs a potentially nested value from flattened keys. + For non-nested values, returns the value directly. + For nested values, reconstructs the nested structure with string keys. + """ + + # Create the prefix to search for nested keys + # e.g., if flattened_key is "state.layer1.weight", prefix becomes "state.layer1.weight." + prefix = f"{flattened_key}." + # Initialize an empty dictionary to build our nested structure + nested_dict: dict[str, Any] = {} + + # Iterate through all keys in the flattened dictionary + for key, value in flattened_dict.items(): + # Check if this key is nested under our target key + # e.g., "state.layer1.weight.exp_avg" starts with "state.layer1.weight." + if not key.startswith(prefix): + # Skip keys that don't belong to this nested structure + continue + + # Remove the prefix to get just the nested part + # e.g., "state.layer1.weight.exp_avg" -> "exp_avg" + remaining_key = key[len(prefix) :] + # Split the remaining key into parts to build the nested structure + # e.g., "step" -> ["step"] or "momentum_buffer" -> ["momentum_buffer"] + parts = remaining_key.split(".") + # Start at the root of our new nested dictionary + current = nested_dict + + # Navigate through or create the nested dictionary structure + # For each part except the last one (which will hold the value) + for part in parts[:-1]: + # Create the nested dictionary if it doesn't exist yet + if part not in current: + current[part] = {} + # Move deeper into the nested structure + assert isinstance(current[part], dict) + current = current[part] + + # Set the value at the final level using the last part as the key + # e.g., current["exp_avg"] = tensor(...) + current[parts[-1]] = value + + # Return the reconstructed nested dictionary (empty dict if no keys matched at all) + return nested_dict + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + # If a parameter is shared, only one of the FQN will be used. + # So we need to verify which if this fqn is actually used in + # the state_dict. + if fqn in info.shared_params_mapping: + in_params = False + for k in param_group: + if k == _PARAMS: + continue + flatten_key = f"{_PG}.{fqn}.{k}" + if flatten_key in state_dict: + in_params = True + break + else: + in_params = True + + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + if not isinstance(params, list): + raise AssertionError(f"Expected list, got {type(params)}") + params.append(fqn) + + # Only add state if param requires grad + if not param.requires_grad: + continue + + # Reconstruct state for this parameter + state[fqn] = {} + for state_name in optim.state[param]: + flattened_state_key = f"{_STATE}.{fqn}.{state_name}" + + if flattened_state_key not in state_dict: + # Try to reconstruct the value + reconstructed_value = _reconstruct_nested_dict( + flattened_state_key, state_dict + ) + cast(DictValueType, state[fqn])[state_name] = ( + reconstructed_value + ) + else: + # Existing keys mean no nesting, directly use the value. + cast(DictValueType, state[fqn])[state_name] = state_dict[ + flattened_state_key + ] + + first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] + for k in param_group: + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + +@torch.no_grad() +def _get_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + info: _StateDictInfo, +) -> OptimizerStateType: + if not info.handle_optim: + return {} + + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} + for optim in optimizers: + _init_optim_state(optim) + osd = _state_dict_fn(optim, "state_dict")() + if info.fsdp_modules: + with info.fsdp_context(): + osd = FSDP.optim_state_dict(model, optim, osd) + + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + # There are no easy ways to do this conversion systematically. + # We can only use a string replacement without correctness check. + if not osd: + continue + for k in list(osd[_STATE].keys()): + if "_orig_mod" in k: + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params + else: + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) + param_pid_mapping = dict(zip(params, range(len(params)))) + fqn_pid_mapping = {} + for key, param in model.named_parameters(): + fqns = _get_fqns(model, key) + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for key '{key}', got {len(fqns)}" + ) + fqn = next(iter(fqns)) + if param not in param_pid_mapping: + continue + pid = param_pid_mapping[param] + fqn_pid_mapping[fqn] = pid + fqn_pid_mapping[pid] = fqn + + # Only convert top-level parameter IDs to FQNs, preserve nested key types + for key in list(osd[_STATE].keys()): + fqn = fqn_pid_mapping[key] + # Move the entire state dict value (which may contain nested integer keys) + # without modifying its internal structure + osd[_STATE][fqn] = osd[_STATE].pop(key) + + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] + + if not osd: + continue + + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) + + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) + ) + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) + + +def _split_optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> OptimizerStateType: + """ + Extract the corresponding optim state_dict from ``optim_state_dict`` for + ``optim`` and return the result optim state_dict. + + Args: + model (nn.Module): the root model. + optim (torch.optim.Optimizer): the optimizer. + optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that + contains the optim state_dict of ``optim``. + info (_StateDictInfo): state dict information. + + Returns: + The optim state_dict of ``optim``. + """ + + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + pg_mapping: dict[int, int] = {} + + if all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE])): + return optim_state_dict + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + + params = pg_state[-1][_PARAMS] + if not isinstance(params, list): + raise AssertionError(f"Expected list, got {type(params)}") + params.append(fqn) + if param.requires_grad: + if fqn in cast(DictValueType, optim_state_dict[_STATE]): + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + elif info.strict: + raise RuntimeError( + f"Missing optimizer state for parameter '{fqn}' in checkpoint. " + "The parameter requires gradients but has no saved optimizer state. " + "To load anyway, use StateDictOptions(strict=False)." + ) + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(list[str], loaded_param_group[_PARAMS]): + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + if len(param_group[_PARAMS]) == 0: + # Param_group with empty params. + ret = [] + for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]): + if len(cast(list[str], loaded_param_group[_PARAMS])) == 0: + ret.append(loaded_param_group) + if len(ret) != 1: + raise ValueError( + "There are param groups that have zero parameters. " + "In such a case, DSD only support exactly one param group " + "with zero parameters." + "But the loaded state_dict has zero or more than one param groups " + "that have zero parameters." + ) + if len(optim_state_dict[_PG]) != len(optim.param_groups): + raise ValueError( + "When there is a parameter group that has zero parameters, " + "multiple optimizers are not supported." + ) + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 + + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): + pg_idx = pg_mapping.get(id(param_group), -1) + if pg_idx == -1: + continue + + for key, value in param_group.items(): + if key == _PARAMS: + continue + # TODO: check if value is the same if exists. + pg_state[pg_idx][key] = value + + return return_osd + + +@torch.no_grad() +def _load_optim_state_dict( + model: nn.Module, + optimizers: tuple[torch.optim.Optimizer, ...], + state_dict: OptimizerStateType, + info: _StateDictInfo, +) -> None: + if not info.handle_optim: + return + + for optim in optimizers: + _init_optim_state(optim) + if state_dict: + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(dict[str, ValueType], state_dict), info + ) + else: + optim_state_dict = {} + if info.fsdp_modules: + # We need to specially handle FlatParameter FSDP as + # FlatParameter FSDP converts the FQNs. + for original_fqn, _ in model.named_parameters(): + fqns = _get_fqns(model, original_fqn) + fqns_with_compiler = _get_fqns( + model, original_fqn, skip_compiler_prefix=False + ) + if fqns == fqns_with_compiler: + continue + + if len(fqns) != 1: + raise AssertionError( + f"Expected 1 FQN for '{original_fqn}', got {len(fqns)}" + ) + fqn = fqns.pop() + fqn_with_compiler = fqns_with_compiler.pop() + for g in optim_state_dict[_PG]: + val = cast(dict[str, Any], g) + params = [ + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] + ] + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) + for k in list(osd_state.keys()): + if fqn in k: + osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) + + with info.fsdp_context(): + optim_state_dict = FSDP.optim_state_dict_to_load( + model, optim, optim_state_dict + ) + elif info.full_state_dict: + info.full_state_dict = False + local_state_dict = _get_optim_state_dict(model, (optim,), info) + info.full_state_dict = True + device = None + + def _device(t): + if t.dim() > 0: + nonlocal device + if device is None: + device = t.device + elif device != t.device: + raise ValueError("Device mismatch") + return t + + _ = tree_map_only(torch.Tensor, _device, local_state_dict) + if device is None: + raise AssertionError("Expected device to be set") + flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) + flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) + if info.broadcast_from_rank0: + _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + else: + _distribute_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd: + if optim_key not in flatten_local_osd: + if optim_key not in osd_mapping: + raise AssertionError( + f"Expected key '{optim_key}' in osd_mapping" + ) + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] + optim_state_dict = _unflatten_state_dict( + flatten_local_osd, local_osd_mapping + ) + for pg in optim_state_dict[_PG]: + if _PARAMS not in pg: + cast(dict[str, ValueType], pg)[_PARAMS] = [] + + # Note that we do not have to convert the FQN back to param id here if + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. + _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) + + +def get_model_state_dict( + model: nn.Module, + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> dict[str, ValueType]: + """ + Return the model state_dict of ``model``. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``model``. + + :rtype: typing.Dict[str, ValueType] + """ + with _gc_context(): + info = _verify_options( + model, + (), + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + _verify_state_dict(model_state_dict, {}, info) + return model_state_dict + + +def get_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> OptimizerStateType: + """ + Return the combined state_dict for optimizers. + + See ``get_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + The state_dict for ``optimizers``. + + :rtype: OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=True, + submodules=submodules, + options=options, + ) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict({}, optim_state_dict, info) + return optim_state_dict + + +def get_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + submodules: Optional[set[nn.Module]] = None, + options: Optional[StateDictOptions] = None, +) -> tuple[dict[str, ValueType], OptimizerStateType]: + """ + Return the model state_dict and optimizers state_dict. + + ``get_state_dict`` can process any module that is parallelized by PyTorch + FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any + combination of these parallelisms. The main functions of ``get_state_dict`` + are: 1.) returning a model and optimizer state_dict that can be resharded + with a different number of trainers and/or different parallelisms. + 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call + these APIs. + 3.) sanity checking the result state_dict. + + The keys of the result state dictionary are the canonical FQNs (Fully + Qualified Names). A canonical FQN refers to the FQN based on a parameter's + position in an nn.Module hierarchy. More specifically, a canonical FQN to a + parameter is the FQN returned by ``module.named_parameters()`` or + ``module.named_buffers()`` when the module is not distributed by any + parallelisms. Since the optimizer internally uses parameter IDs to represent + a parameter, there will be a conversion from the parameter IDs to the + canonical FQNs when calling this API. + + ``get_state_dict`` can also process a module that is not parallelized. In + such a case, ``get_state_dict`` only performs one function -- converting the + optimizer parameter IDs to the canonical FQNs. + + Example: + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.checkpoint.state_dict import get_state_dict + + >>> fsdp_model = FSDP(copy.deepcopy(model)) + >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + >>> ddp_model = DDP(copy.deepcopy(model)) + >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3) + + + >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) + + >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), + >>> # the asserts will fail. + >>> assert ddp_state_dict == fsdp_state_dict + >>> assert ddp_optim_state == fsdp_optim_state_dict + + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[None, Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters + that belong to the submodules. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be returned. See + `StateDictOptions` for the details. + + Returns: + ``Tuple`` that contain model state_dict and optimizer state_dict. + + :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] + """ + + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, + optimizers, + optim_only=False, + submodules=submodules, + options=options, + ) + model_state_dict = _get_model_state_dict(model, info) + optim_state_dict = _get_optim_state_dict(model, optimizers, info) + _verify_state_dict(model_state_dict, optim_state_dict, info) + return model_state_dict, optim_state_dict + + +def _unflatten_model_state_dict( + model: nn.Module, + state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]], +) -> dict[str, ValueType]: + if not state_dict: + return {} + + if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + "is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocessing the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + stacklevel=2, + ) + cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict) + new_state_dict: dict[str, ValueType] = {} + for submodule, sub_state_dict in cast_state_dict.items(): + for name, m in model.named_modules(): + if m != submodule: + continue + + fqns = _get_fqns(model, name) + if len(fqns) != 1: + raise AssertionError( + "FQNs for a submodule should only have 1 element" + ) + prefix = f"{next(iter(fqns))}." + new_state_dict.update( + {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} + ) + return new_state_dict + else: + return cast(dict[str, ValueType], state_dict) + + +def set_model_state_dict( + model: nn.Module, + model_state_dict: dict[str, ValueType], + *, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + info = _verify_options(model, (), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + +def set_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + optim_state_dict: OptimizerStateType, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Load the optimizers state_dict. + + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after + ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be + initialized correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + + +def set_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + model_state_dict: dict[str, ValueType], + optim_state_dict: OptimizerStateType, + options: Optional[StateDictOptions] = None, +) -> _IncompatibleKeys: + """Load the model state_dict and optimizers state_dict. + + The counterpart of ``get_state_dict`` to set the state_dict to the model and + optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not + have to be returned by ``get_state_dict`` but must meet the following + requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, + 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, + 3) optimizer state_dict cannot contain the parameter IDs; the keys should be + the canonical FQNs. + + WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()`` + is called on the optimizers. Otherwise, the optimizer states won't be initialized + correctly. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys of the model state_dict. + * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. + + :type model_state_dict: typing.Dict[str, ValueType] + :type optim_state_dict: typing.OptimizerStateType + """ + + model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict( + model, model_state_dict + ) + with _gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options( + model, optimizers, optim_only=not model_state_dict, options=options + ) + + _verify_state_dict(model_state_dict, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) + return _load_model_state_dict(model, model_state_dict, info) + + +# TODO: correct the state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_model_state_dict( + model: nn.Module, + *, + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_model_state_dict, + model=model, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + model.state_dict = state_dict_call + + _load_state_dict_call = functools.partial( + set_model_state_dict, + model=model, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(model_state_dict=state_dict) + + model.load_state_dict = load_state_dict_call + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + + +# TODO: correct the load_state_dict function signature. +# TODO: this API is not yet fully tested. Make it private +@no_type_check +def _patch_optimizer_state_dict( + model: nn.Module, + *, + optimizers: tuple[torch.optim.Optimizer, ...], + options: Optional[StateDictOptions] = None, +) -> None: + """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. + + Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to + be a partial function to call ``get_state_dict`` and ``set_state_dict``. + + Note that if there are multiple optimizers, all of the optimizers will be patched. + So users only need to call one of the state_dict() to get the full result. + + Example: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.checkpoint.state_dict import patch_model_state_dict + + model = fsdp(model) + patch_model_state_dict(model) + + Args: + model (nn.Module): the nn.Module to the model. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + Returns: + None + """ + + _state_dict_call = functools.partial( + get_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def state_dict_call(): + return _state_dict_call() + + _load_state_dict_call = functools.partial( + set_optimizer_state_dict, + model=model, + optimizers=optimizers, + options=options, + ) + + def load_state_dict_call(state_dict: dict[str, Any]): + _load_state_dict_call(optim_state_dict=state_dict) + + _patched_state_dict.add(state_dict_call) + _patched_state_dict.add(load_state_dict_call) + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + for optim in optimizers: + optim.state_dict = state_dict_call + optim.load_state_dict = load_state_dict_call diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..178e190e937fb5fab1aa582464e20f1cff8d7abf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import logging +import os +import warnings +from typing import Any, cast, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.stateful import Stateful + +from ._storage_utils import _storage_setup +from .default_planner import DefaultLoadPlanner +from .planner import LoadPlan, LoadPlanner +from .storage import StorageReader +from .utils import _api_bc_check, _DistWrapper, _profile + + +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import Metadata + +__all__ = ["load_state_dict", "load"] + +logger = logging.getLogger() + + +@deprecated( + "`load_state_dict` is deprecated and will be removed in future versions. " + "Please use `load` instead.", + category=FutureWarning, +) +def load_state_dict( + state_dict: dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + """This method is deprecated. Please switch to 'load'.""" + storage_reader.reset() + with _profile(): + # TODO: test returning `load` here instead. + return _load_state_dict( + state_dict, + storage_reader, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) +@_api_bc_check +def load( + state_dict: dict[str, Any], + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + planner: Optional[LoadPlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, +) -> None: + """ + Load a checkpoint into a distributed state dict in SPMD style. + + Each rank must have the same keys in their ``state_dict`` provided to this + API. Mismatched keys may result in hangs or errors. If unsure, you can use + the ``utils._assert_same_keys`` API to check (but may incur communication + costs). + + Each rank will try to read the least amount of data necessary + to fulfill the requested `state_dict`. When loading :class:`ShardedTensor` + or :class:`DTensor` instances, each rank only reads data for their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + load will first call ``state_dict`` before attempting deserialization, followed by + ``load_state_dict`` once the deserialization is complete. + For each non-``Stateful`` object, load will deserialize the object, and then replace + it in the ``state_dict`` with the deserialized object. + + .. warning:: + All tensors in ``state_dict`` must be allocated on their + destination device *prior to* calling this function. + + All non-tensor data is loaded using `torch.load()` and modified in place + on state_dict. + + .. warning:: + Users must call `load_state_dict` on the root module to ensure load + pos-processing and non-tensor data properly propagates. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + state_dict (Dict[str, Any]): The state_dict to load the checkpoint into. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[LoadPlanner]): + Instance of LoadPlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): If ``True``, this function will assume the intent is to load + a checkpoint without using cross-rank synchronization. (Default: ``False``) + Returns: + None. + + Examples + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + >>> optimizer = Adagrad(my_model.parameters()) + >>> model_state_dict = my_model.state_dict() + >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( + ... "/checkpoint/1" + ... ) + + >>> torch.distributed.checkpoint.load_state_dict( + >>> state_dict=model_state_dict, + >>> storage_reader=fs_storage_reader, + >>> ) + + >>> # module.load_state_dict() function might have customized steps + >>> # to flush the state_dict, must call it to + >>> # ensure correct behavior. + >>> my_model.load_state_dict(model_state_dict) + + .. note:: + load_state_dict uses collectives to coordinate reads across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that each + rank has an individual GPU, via ``torch.cuda.set_device()``. + """ + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process.", + stacklevel=2, + ) + + with _profile(): + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + # All ranks must have the same keys in their `state_dict` provided to + # this API. See documentation for more details. + # Here we simply sort the keys to ensure that all ranks load values in + # the same order. + keys = sorted(state_dict.keys()) + + statetful_sd = {} + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + statetful_sd[key] = ( + elem.state_dict() if isinstance(elem, Stateful) else elem + ) + + _load_state_dict( + state_dict=statetful_sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=planner, + ) + for key in keys: + if key not in state_dict: + continue + elem = state_dict[key] + if isinstance(elem, Stateful): + # If the state_dict is a Stateful object, + # DCP does an in-place load in the original state dict. + elem.load_state_dict(statetful_sd[key]) + else: + # Otherwise, replace the state_dict with the loaded state_dict. + state_dict[key] = statetful_sd[key] + + +def _load_state_dict( + state_dict: dict[str, Any], + storage_reader: StorageReader, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + use_collectives = True + metadata: Optional[Metadata] = None + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + nonlocal use_collectives + nonlocal metadata + + # Use global metadata if available, otherwise fallback to rank local metadata + try: + metadata = storage_reader.read_metadata() + except Exception: + logger.info( + "Global metadata is not found. Falling back to rank local metadata." + ) + + if ( + not metadata + and "kwargs" in inspect.signature(storage_reader.read_metadata).parameters + ): + try: + metadata = storage_reader.read_metadata(rank=distW.rank) # noqa: F841 + use_collectives = False + except Exception: + logger.info("Rank local metadata is not found.") + + if planner is None: + raise AssertionError("planner is None") + if metadata is None: + raise AssertionError("metadata is None") + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + + if ( + "kwargs" + in inspect.signature(storage_reader.set_up_storage_reader).parameters + ): + storage_reader.set_up_storage_reader( + metadata, + distW.is_coordinator, + rank=distW.rank, + use_collectives=use_collectives, + ) + else: + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + if planner is None: + raise AssertionError("planner is None") + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: Optional[LoadPlan] = None + if use_collectives: + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + local_plan: LoadPlan = local_step() + global_plan: list[LoadPlan] = global_step([local_plan]) + central_plan = global_plan[0] + + @_dcp_method_logger(**ckpt_kwargs) + def read_data(): + if planner is None: + raise AssertionError("planner is None") + if central_plan is None: + raise AssertionError("central_plan is None") + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + + all_reads.wait() + return None + + if use_collectives: + _ = distW.all_gather("read", read_data) + else: + read_data() + distW.barrier() + + +def _load_state_dict_from_keys( + keys: Optional[Union[set[str], str]] = None, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_reader: Optional[StorageReader] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> dict[str, Any]: + """ + Load only the specified keys from the checkpoint, if no keys are specified, the entire + checkpoint will be loaded. Note, this method completely loads the checkpoint into the + current process and is not distributed. + + .. warning:: + + + .. warning:: + + All non-tensor data is loaded using `torch.load()` + + .. note: + As opposed to the usual pattern, this function does not take a state dict as input + and does not load inplace. Instead, a new state dict is directly initialized and read + from file. + + .. note: + If no process group is initialized, this function will assume the intent + is to load a checkpoint into the local process. This can be useful in the + case of local inference, and when using regular Tensors (as opposed to DTensor + or ShardedTensor) + + .. note: + Rank 0 is assumed to be the coordinator rank. + + Args: + keys (Optional[Union[set[str], str]]): + Loads any key specified in this set. If no keys are specified, the entire checkpoint + is loaded. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_reader (Optional[StorageReader]): + Instance of StorageWriter used to perform reads. If this is not + specified, DCP will automatically infer the reader based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + + Returns: + State dict from specified keys + """ + torch._C._log_api_usage_once( + "torch.distributed.checkpoint._load_state_dict_from_keys" + ) + + no_dist = not (dist.is_available() and dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process.", + stacklevel=2, + ) + + storage_reader = cast( + StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) + ) + + if isinstance(keys, str): + keys = {keys} + + sd: dict[str, Any] = {} + _load_state_dict( + state_dict=sd, + storage_reader=storage_reader, + process_group=process_group, + no_dist=no_dist, + planner=_EmptyStateDictLoadPlanner(keys=keys), + ) + + return sd diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..370f97cd1cd013246563f021749b6537a327b235 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py @@ -0,0 +1,496 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import os +import warnings +from concurrent.futures import Future +from dataclasses import dataclass +from enum import Enum +from typing import cast, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +from torch.distributed._state_dict_utils import STATE_DICT_TYPE +from torch.distributed.checkpoint._async_process_executor import ( + _ProcessBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._async_thread_executor import ( + _ThreadBasedAsyncCheckpointExecutor, +) +from torch.distributed.checkpoint._storage_utils import _storage_setup +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.logger import _dcp_method_logger +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.staging import ( + AsyncStager, + DefaultStager, + StagingOptions, +) +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.storage import StorageWriter, WriteResult +from torch.distributed.distributed_c10d import _get_default_group + +from .utils import _api_bc_check, _DistWrapper, _profile + + +if TYPE_CHECKING: + from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor + + +__all__ = [ + "save_state_dict", + "save", + "async_save", + "AsyncCheckpointerType", + "AsyncSaveResponse", +] + + +class AsyncCheckpointerType(Enum): + """Enum for async checkpointer type.""" + + THREAD = "thread" + PROCESS = "process" + + +@deprecated( + "`save_state_dict` is deprecated and will be removed in future versions." + "Please use `save` instead.", + category=FutureWarning, +) +def save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, +) -> Metadata: + """This method is deprecated. Please switch to 'save'.""" + storage_writer.reset() + + # TODO: test returning `save` here instead. + with _profile(): + return _save_state_dict( + state_dict, + storage_writer, + process_group, + coordinator_rank, + no_dist, + planner, + ) + + +@_dcp_method_logger(log_exceptions=True) # type: ignore[arg-type] +@_api_bc_check +def save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Metadata: + """ + Save a distributed model in SPMD style. + + This function is different from ``torch.save()`` as it handles + ``ShardedTensor`` , and ``DTensor`` by having each rank only save their local shards. + + For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), + save will call ``state_dict`` before serialization. + + .. warning:: + There is no guarantees of Backwards Compatibility across PyTorch versions + for saved state_dicts. + + .. warning:: + If using the `process_group` argument, make sure that only its ranks + call `save_state_dict` and that all data in state_dict belong to it. + + .. note:: + When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of + the shard_group should be calling `save_state_dict` and the corresponding process + group needs to be passed in. + + .. note:: + If no process group is available, this function assumes the intention is to save the + state_dict in the local process. + + .. note: + Rank 0 is assumed to be the coordinator rank. + + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform writes. If this is not + specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + no_dist (bool): + If ``True``, this function will assume the intent is to load + a checkpoint on a single rank/process. + (Default: ``False``) + use_collectives (bool): If ``False``, this function will assume the intent is to save + a checkpoint without using cross-rank synchronization. + (Default: ``True``) + This configuration is experimental and should be used with caution. + It will change the format of the saved checkpoint and may not be backward compatible. + + Returns: + Metadata: Metadata object for the saved checkpoint. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> torch.distributed.checkpoint.save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + + .. note:: + save_state_dict uses collectives to coordinate writes across ranks. + For NCCL-based process groups, internal tensor representations of + objects must be moved to the GPU device before communication takes place. + In this case, the device used is given by ``torch.cuda.current_device()`` + and it is the user's responsibility to ensure that this is set so that + each rank has an individual GPU, via ``torch.cuda.set_device()``. + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.save") + + no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized()) + if no_dist: + warnings.warn( + "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process.", + stacklevel=2, + ) + + with _profile(): + storage_writer = cast( + StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) + ) + + return _save_state_dict( + state_dict=_stateful_to_state_dict(state_dict), + storage_writer=storage_writer, + process_group=process_group, + no_dist=no_dist, + planner=planner, + use_collectives=use_collectives, + ) + + +@dataclass +class AsyncSaveResponse: + """This class contains futures for staging and upload completion. + It is returned by async_save(). + staging_completion is a future that indicates when local copy + of state_dict is complete. + upload_completion is a future that indicates when a checkpoint + completed saving. + """ + + staging_completion: Future[None] + upload_completion: Future[None] + + +@_dcp_method_logger(log_exceptions=True) +def async_save( + state_dict: STATE_DICT_TYPE, + *, + checkpoint_id: Union[str, os.PathLike, None] = None, + storage_writer: Optional[StorageWriter] = None, + planner: Optional[SavePlanner] = None, + process_group: Optional[dist.ProcessGroup] = None, + async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD, + async_stager: Optional[AsyncStager] = None, + no_dist: bool = False, + use_collectives: bool = True, +) -> Union[Future, AsyncSaveResponse]: + """Asynchronous version of ``save``. This code first de-stages the state_dict on to the + staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. + + .. warning:: + This feature is experimental and subject to change. + MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED + + Args: + state_dict (Dict[str, Any]): The state_dict to save. + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + storage_writer (Optional[StorageWriter]): + Instance of StorageWriter used to perform 'stage' and 'save'. If + this is not specified, DCP will automatically infer the writer based on the + checkpoint_id. If checkpoint_id is also None, an exception will + be raised. (Default: ``None``) + planner (Optional[SavePlanner]): + Instance of SavePlanner. If this is not specified, the default + planner will be used. (Default: ``None``) + process_group (Optional[ProcessGroup]): + ProcessGroup to be used for cross-rank synchronization. + (Default: ``None``) + async_checkpointer_type (AsyncCheckpointerType): + whether to do checkpoint in separate thread or process + (Default: ``AsyncCheckpointerType.THREAD``) + async_stager (AsyncStager): + provides staging implementation. If storage_writer implements AsyncStager + and async_stager is provided, async_stager will be used for staging + no_dist (bool): + If ``True``, this function will assume the intent is to save + a checkpoint on a single rank/process. + (Default: ``False``) + use_collectives: If False, Save the checkpoint without rank coordination. (Default: ``True``) + This configuration is experimental and should be used with caution. + It will change the format of the saved checkpoint and may not be backward compatible. + + Returns: + Future: A future holding the resultant Metadata object from `save`. + + Example: + >>> # xdoctest: +SKIP + >>> my_model = MyModule() + + >>> state_dict = {"model": my_model} + + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) + >>> checkpoint_future = torch.distributed.checkpoint.async_save( + >>> state_dict=state_dict, + >>> storage_writer=fs_storage_writer, + >>> ) + >>> + >>> # ... do some work ... + >>> + >>> checkpoint_future.result() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save") + + if dist.is_available() and dist.is_initialized(): + pg = process_group or _get_default_group() + if torch.device("cpu") not in pg._device_types: + raise AssertionError( + "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + ) + + if async_stager is None: + if storage_writer is not None and isinstance(storage_writer, AsyncStager): + # bwc with old storage_writers + async_stager = storage_writer + else: + async_stager = DefaultStager( + StagingOptions( + False, + False, + False, + False, + ) + ) + + state_dict = _stateful_to_state_dict(state_dict) + + @_dcp_method_logger(log_exceptions=True) + def stage_state_dict() -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: + return async_stager.stage(state_dict) + + staging_future_or_state_dict = stage_state_dict() + + upload_executor: _AsyncCheckpointExecutor = ( + _ProcessBasedAsyncCheckpointExecutor() + if async_checkpointer_type == AsyncCheckpointerType.PROCESS + else _ThreadBasedAsyncCheckpointExecutor() + ) + + upload_future: Future = upload_executor.execute_save( + staging_future_or_state_dict, + checkpoint_id=checkpoint_id, + # pyrefly: ignore [bad-argument-type] + storage_writer=storage_writer, + planner=planner, + process_group=process_group, + no_dist=no_dist, + use_collectives=use_collectives, + ) + + if isinstance(staging_future_or_state_dict, Future): + staging_future = staging_future_or_state_dict + return_staging_future: Future[None] = Future() + + def callback( + original_staging_future: Future[STATE_DICT_TYPE], + return_staging_future: Future[None] = return_staging_future, + ): + try: + original_staging_future.result() + return_staging_future.set_result(None) + except Exception as e: + return_staging_future.set_exception(e) + + if not staging_future.done(): + staging_future.add_done_callback(callback) + else: + return_staging_future.set_result(None) + + # return new AsyncSaveResponse for users using new ZOC implementation + return AsyncSaveResponse( + staging_completion=return_staging_future, upload_completion=upload_future + ) + else: + + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if async_stager.should_synchronize_after_execute: + async_stager.synchronize_staging() + + maybe_synchronize_staging() + return upload_future + + +@_dcp_method_logger(log_exceptions=True) +def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: + """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object.""" + stateful_state_dict = {} + for key, elem in state_dict.items(): + # Apply _dcp_method_logger to each state_dict() call + def _elem_to_state_dict(elem): + return elem.state_dict() if isinstance(elem, Stateful) else elem + + _elem_to_state_dict.__name__ = f"_stateful_to_state_dict.{key}" + + stateful_state_dict[key] = _dcp_method_logger(log_exceptions=True)( + _elem_to_state_dict + )(elem) + return stateful_state_dict + + +def _save_state_dict( + state_dict: STATE_DICT_TYPE, + storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, + use_collectives: bool = True, +) -> Metadata: + torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") + + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + if planner is None: + raise AssertionError("planner is None") + + global_metadata = None + + ckpt_kwargs = {} + if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: + ckpt_kwargs["checkpoint_id"] = ckpt_id + ckpt_kwargs["process_group"] = distW.group + + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + if planner is None: + raise AssertionError("planner is None") + storage_meta = storage_writer.storage_meta() + if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters: + warnings.warn( + "The function definition for SavePlanner.set_up_planner has been updated" + " to include the storage_meta argument. Please update your implementation" + " to include this parameter.", + stacklevel=2, + ) + planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type] + else: + planner.set_up_planner( + state_dict=state_dict, + storage_meta=storage_meta, + is_coordinator=distW.is_coordinator, + ) + + if ( + "kwargs" + in inspect.signature(storage_writer.set_up_storage_writer).parameters + ): + storage_writer.set_up_storage_writer( + distW.is_coordinator, + rank=distW.rank, + use_collectives=use_collectives, + ) + else: + storage_writer.set_up_storage_writer(distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + @_dcp_method_logger(**ckpt_kwargs) + def global_step(all_local_plans): + nonlocal global_metadata + + if planner is None: + raise AssertionError("planner is None") + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan: Optional[SavePlan] = None + if use_collectives: + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + local_plan: SavePlan = local_step() + global_plan: list[SavePlan] = global_step([local_plan]) + central_plan = global_plan[0] + + @_dcp_method_logger(**ckpt_kwargs) + def write_data(): + if planner is None: + raise AssertionError("planner is None") + if central_plan is None: + raise AssertionError("central_plan is None") + final_local_plan = planner.finish_plan(central_plan) + all_writes = storage_writer.write_data(final_local_plan, planner) + + all_writes.wait() + return all_writes.value() + + @_dcp_method_logger(**ckpt_kwargs) + def finish_checkpoint(all_results): + if global_metadata is None: + raise AssertionError("global_metadata is None") + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata + + if use_collectives: + metadata = distW.all_reduce("write", write_data, finish_checkpoint) + else: + write_results: list[WriteResult] = write_data() + metadata = finish_checkpoint([write_results]) + distW.barrier() + + return metadata diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py new file mode 100644 index 0000000000000000000000000000000000000000..15e227d92fb5d29631b0316b3971c435120ad15b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/stateful.py @@ -0,0 +1,42 @@ +from typing import Any, TypeVar +from typing_extensions import Protocol, runtime_checkable + + +__all__ = ["Stateful", "StatefulT"] + + +@runtime_checkable +class Stateful(Protocol): + """ + Stateful protocol for objects that can be checkpointed and restored. + """ + + def state_dict(self) -> dict[str, Any]: + """ + Objects should return their state_dict representation as a dictionary. + The output of this function will be checkpointed, and later restored in + `load_state_dict()`. + + .. warning:: + Because of the inplace nature of restoring a checkpoint, this function + is also called during `torch.distributed.checkpoint.load`. + + + Returns: + Dict: The objects state dict + """ + + ... + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Restore the object's state from the provided state_dict. + + Args: + state_dict: The state dict to restore from + """ + + ... + + +StatefulT = TypeVar("StatefulT", bound=Stateful) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b184d7b1700528ad22bc10726cb6619975e8d9e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/storage.py @@ -0,0 +1,288 @@ +import abc +import os +from dataclasses import dataclass +from typing import Any, Optional, Union + +from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + SavePlan, + SavePlanner, +) +from torch.futures import Future + + +__all__ = ["WriteResult", "StorageWriter", "StorageReader"] + + +@dataclass(frozen=True) +class WriteResult: + index: MetadataIndex + + size_in_bytes: int + storage_data: Any + + +class StorageWriter(abc.ABC): + """ + Interface used by ``save_state_dict`` to write to storage. + + One StorageWriter instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expect the following sequence of calls. + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) set_up_storage_writer() + 2) (all ranks) prepare_local_plan() + 3) (coordinator) prepare_global_plan() + 4) (all ranks) write_data() + 5) (coordinator) finish() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint write is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint write. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def set_up_storage_writer( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + """ + Initialize this instance. + + Args: + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in SavePlan::storage_data. + + Args: + plan (SavePlan): The local plan from the ``SavePlanner`` in use. + + Returns: + A transformed ``SavePlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: + """ + Perform centralized planning of storage. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in SavePlan::storage_data. + + Args: + plans: A list of ``SavePlan`` instances, one for each rank. + + Returns: + A list of transformed ``SavePlan`` after storage global planning + """ + + @abc.abstractmethod + def write_data( + self, plan: SavePlan, planner: SavePlanner + ) -> Future[list[WriteResult]]: + """ + Write all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``SavePlanner::resolve_data`` on each item + from the plan to get access to the underlying object to write. + + Subclasses should lazily call `resolve_data` as it can allocate memory. + In case of tensors, make following assumptions: + + - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` + - They might be views or not contiguous. Only the projection needs to be saved. + + Args: + plan (SavePlan): The save plan to execute. + planner (SavePlanner): Planner object to be used to resolve items to data. + + Returns: + A future that completes to a list of WriteResult + """ + + @abc.abstractmethod + def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: + """ + Write the metadata and marks the current checkpoint as successful. + + The actual format/schema used for serializing `metadata` is an + implementation detail. The only requirement is that it's recoverable + in to the same object graph. + + Args: + metadata (Metadata): metadata for the new checkpoint + results: A list of WriteResults from all ranks. + + Returns: + None + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the storage. This allow + us to enable automatic storage selection. + """ + ... + + def storage_meta(self) -> Optional[StorageMeta]: + """ + Return the storage-specific metadata. This is used to store additional information + in a checkpoint that can be useful for providing request-level observability. StorageMeta + is passed to the ``SavePlanner`` during save calls. Returns None by default. + + TODO: provide an example + """ + return None + + +class StorageReader(abc.ABC): + """ + Interface used by ``load_state_dict`` to read from storage. + + One StorageReader instance acts as both the coordinator and the follower + in a distributed checkpoint. As part of initialization, each instance + is told its role. + + A subclass should expected the following sequence of calls by ``load_state_dict``: + + 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. + 1) (all ranks) read_metadata() + 2) (all ranks) set_up_storage_reader() + 3) (all ranks) prepare_local_plan() + 4) (coordinator) prepare_global_plan() + 5) (all ranks) read_data() + """ + + @abc.abstractmethod + def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: + """ + Calls to indicates a brand new checkpoint read is going to happen. + A checkpoint_id may be present if users set the checkpoint_id for + this checkpoint read. The meaning of the checkpiont_id is + storage-dependent. It can be a path to a folder/file or a key for + a key-value storage. + + Args: + checkpoint_id (Union[str, os.PathLike, None]): + The ID of this checkpoint instance. The meaning of the checkpoint_id + depends on the storage. It can be a path to a folder or to a file. + It can also be a key if the storage is more like a key-value store. + (Default: ``None``) + """ + ... + + @abc.abstractmethod + def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: + """ + Read the checkpoint metadata. + + Returns: + The metadata object associated with the checkpoint being loaded. + + """ + + @abc.abstractmethod + def set_up_storage_reader( + self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: + """ + Initialize this instance. + + Args: + metadata (Metadata): The metadata schema to use. + is_coordinator (bool): Whether this instance is responsible for coordinating + the checkpoint. + """ + + @abc.abstractmethod + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """ + Perform storage-specific local planning. + + While this method can produce a completely different plan, the recommended + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plan (LoadPlan): The local plan from the ``LoadPlan`` in use. + + Returns: + A transformed ``LoadPlan`` after storage local planning + """ + + @abc.abstractmethod + def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]: + """ + Perform centralized planning of storage loading. + + This method is only called on the coordinator instance. + + While this method can produce a completely different plan, the preferred + way is to store storage specific data in LoadPlan::storage_data. + + Args: + plans: A list of ``LoadPlan`` instances, one for each rank. + + Returns: + A list of transformed ``LoadPlan`` after storage global planning + """ + + @abc.abstractmethod + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + """ + Read all items from ``plan`` using ``planner`` to resolve the data. + + A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO + object into the right place. + + A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the + tensors that in should load data into. + + It's the StorageLayer responsibility to properly schedule any cross device copies + required. + + Args: + plan (LoadPlan): The local plan to execute on + planner (LoadPlanner): The planner object to use to resolve items. + + Returns: + A future that completes once all reads are finished. + """ + + @classmethod + @abc.abstractmethod + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: + """ + Check if the given checkpoint_id is supported by the storage. This allow + us to enable automatic storage selection. + """ + ... diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..073649c5f124d1817af12d161d8a80b76ae3ceda --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +import cProfile +import inspect +import io +import itertools +import os +import warnings +from collections.abc import Callable, Sequence +from contextlib import contextmanager +from functools import wraps +from pstats import Stats +from typing import Any, cast, Optional, TypeVar, Union + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard + +from .api import ( + _is_wrapped_exception, + _wrap_exception, + CheckpointException, + WRAPPED_EXCEPTION, +) +from .metadata import MetadataIndex, STATE_DICT_TYPE + + +__all__ = ["find_tensor_shard", "find_state_dict_object"] + +T = TypeVar("T") +R = TypeVar("R") + + +def _get_failure_dict( + results: list[Union[T, WRAPPED_EXCEPTION]], +) -> dict[int, WRAPPED_EXCEPTION]: + return cast( + dict[int, WRAPPED_EXCEPTION], + {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, + ) + + +def _all_gather_keys( + local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None +) -> set[str]: + """Gathers all keys, and returns them sorted.""" + keys = list(local_dict.keys()) + gathered_keys: list[list[str]] = [None] * dist.get_world_size(group) # type: ignore[list-item] + + dist.all_gather_object(gathered_keys, keys, group=group) + return set(itertools.chain.from_iterable(gathered_keys)) + + +def _assert_same_keys( + state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None +) -> None: + """ + Asserts that all ranks have the same keys in their state dict. + This is a collective call which requires all ranks in ``process_group`` to + join. It will also induce cross-rank communication and block CPU. + """ + + if dist.get_world_size(process_group) == 1: + return + + all_keys = _all_gather_keys(state_dict, process_group) + my_keys = set(state_dict.keys()) + diff = all_keys - my_keys + if len(diff) > 0: + raise AssertionError( + f"Key(s) present in other ranks but not this one, difference: {diff}" + ) + + +class _DistWrapper: + """ + This is a wrapper around PG that provides a series of features around object collectives. + + It works without distributed initialized, where most collectives turns into nops. + + All variants that take functions are exception robust, meaning that if one or more + ranks raise errors, all ranks will observe those. + """ + + def __init__( + self, + group: Optional[dist.ProcessGroup], + use_dist: bool, + coordinator_rank: int, + ): + self.group = group + self.use_dist = use_dist + self.coordinator_rank = coordinator_rank + if self.use_dist: + self.global_coordinator_rank = ( + dist.get_global_rank(group, coordinator_rank) + if group is not None + else coordinator_rank + ) + self.rank = dist.get_rank(group) + self.is_coordinator = self.rank == coordinator_rank + else: + self.global_coordinator_rank = 0 + self.rank = 0 + self.is_coordinator = True + + def get_rank(self) -> int: + return self.rank + + def get_world_size(self) -> int: + if self.use_dist: + return dist.get_world_size(self.group) + return 1 + + def broadcast_object(self, object: Optional[T]) -> T: + """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled.""" + object_list = [object] + if self.use_dist: + dist.broadcast_object_list( + object_list=object_list, + group=self.group, + src=self.global_coordinator_rank, + ) + return cast(T, object_list[0]) + + def gather_object(self, object: T) -> Optional[list[T]]: + """Implement functionality similar to c10d::gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = ( + cast(list[T], [None] * dist.get_world_size(self.group)) + if self.is_coordinator + else None + ) + + dist.gather_object( + obj=object, + object_gather_list=gather_objs if self.is_coordinator else None, + dst=self.global_coordinator_rank, + group=self.group, + ) + result = gather_objs + else: + result = [object] + return result + + def all_gather_object(self, object: T) -> list[T]: + """Implement functionality similar to c10d::all_gather_object but without distributed enabled.""" + if self.use_dist: + gather_objs = cast(list[T], [None] * dist.get_world_size(self.group)) + + dist.all_gather_object( + object_list=gather_objs, obj=object, group=self.group + ) + else: + gather_objs = [object] + return gather_objs + + def scatter_object(self, object_list: Optional[list[T]]) -> T: + """Implement functionality similar to c10d::scatter_object but without distributed enabled.""" + if self.use_dist: + gather_result = cast(list[T], [None]) + dist.scatter_object_list( + scatter_object_output_list=gather_result, + scatter_object_input_list=object_list if self.is_coordinator else None, + src=self.global_coordinator_rank, + group=self.group, + ) + + local_reply = gather_result[0] + else: + if object_list is None: + raise AssertionError("object_list is None") + local_reply = object_list[0] + return local_reply + + def reduce_scatter( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], list[R]], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Scatter to each rank part of the result. + """ + local_data: Union[WRAPPED_EXCEPTION, T] + try: + local_data = map_fun() + except BaseException as e: # noqa: B036 + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + all_results: Optional[list[Union[R, CheckpointException]]] = None + if self.is_coordinator: + if all_data is None: + raise AssertionError("all_data is None") + node_failures = _get_failure_dict(all_data) + + if len(node_failures) == 0: + try: + # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? + all_results = cast( + list[Union[R, CheckpointException]], + reduce_fun(cast(list[T], all_data)), + ) + except BaseException as e: # noqa: B036 + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + all_results = [ + CheckpointException(step, node_failures) + ] * self.get_world_size() + + result = self.scatter_object(all_results) + if isinstance(result, CheckpointException): + raise result + return result + + def all_reduce( + self, + step: str, + map_fun: Callable[[], T], + reduce_fun: Callable[[list[T]], R], + ) -> R: + """ + Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. + + This method operates in the following way: + Run ``map_fun`` on all ranks + Gather results on rank 0 + Call ``reduce_fun`` on all those values + Broadcast the reduced value to all ranks. + """ + local_data: Union[T, WRAPPED_EXCEPTION] + try: + local_data = map_fun() + except BaseException as e: # noqa: B036 + local_data = _wrap_exception(e) + + all_data = self.gather_object(local_data) + result: Optional[Union[R, CheckpointException]] = None + if self.is_coordinator: + if all_data is None: + raise AssertionError("all_data is None") + node_failures = _get_failure_dict(all_data) + if len(node_failures) == 0: + try: + result = reduce_fun(cast(list[T], all_data)) + except BaseException as e: # noqa: B036 + node_failures[self.rank] = _wrap_exception(e) + + if len(node_failures) > 0: + result = CheckpointException(step, node_failures) + + # pyrefly: ignore [bad-argument-type] + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(R, final_result) + + def all_gather( + self, + step: str, + map_fun: Callable[[], T], + ) -> list[T]: + """ + Compute a value on each rank, then all_gather them. + + This method operates in the following way: + Run ``map_cp`` on all ranks + all_gather the values to all ranks + """ + result: Union[T, WRAPPED_EXCEPTION] + try: + result = map_fun() + except BaseException as e: # noqa: B036 + result = _wrap_exception(e) + + all_results = self.all_gather_object(result) + + node_failures = _get_failure_dict(all_results) + if len(node_failures) > 0: + raise CheckpointException(step, node_failures) + return cast(list[T], all_results) + + def broadcast( + self, + step: str, + map_fun: Callable[[], T], + ) -> T: + """ + Compute a value on rank 0 and broadcast it. + + This method operates in the following way: + Run ``map_cp`` on rank 0 + broadcast the value + """ + result: Optional[Union[T, CheckpointException]] = None + if self.is_coordinator: + try: + result = map_fun() + except BaseException as e: # noqa: B036 + result = CheckpointException(step, {self.rank: _wrap_exception(e)}) + # pyrefly: ignore [bad-argument-type] + final_result = self.broadcast_object(result) + if isinstance(final_result, CheckpointException): + raise final_result + return cast(T, final_result) + + def barrier(self) -> None: + """ + Add a synchronization point across all processes when using distributed. + If torch.distributed is initialized, this function will invoke a barrier across the global process group. + If torch.distributed is not initialized, this function is a no-op. + """ + if not self.use_dist: + return + dist.barrier(group=self.group) + + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError( + f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" + ) + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if ( + len(shards) > index.index + and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset + ): + return shards[index.index] + + for shard in shards: + if torch.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + + +def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: + if hasattr(tensor, "__get_tensor_shard__"): + # DTensor implements _Checkpointable + return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + if isinstance(tensor, ShardedTensor): + return _find_shard(tensor, index).tensor + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == torch.Size([0] * len(tensor.size())): + return tensor + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + + if isinstance(obj, torch.Tensor): + return find_tensor_shard(obj, index) + elif index.offset is not None: + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) + return obj + + +def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a + i_b for i_a, i_b in zip(a, b)] + + +def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]: + return [i_a - i_b for i_a, i_b in zip(a, b)] + + +class _ReaderView(io.IOBase): + def __init__(self, base_stream: io.IOBase, offset: int, len: int): + super().__init__() + self.offset = offset + self.len = len + self.base_stream = base_stream + self.seek(0) + + def seek(self, offset: int, whence: int = os.SEEK_SET, /) -> int: + if whence == os.SEEK_SET: + offset = self.offset + offset + elif whence == os.SEEK_END: + whence = os.SEEK_SET + offset = (self.offset + self.len) - offset + return self.base_stream.seek(offset, whence) + + def tell(self) -> int: + return self.base_stream.tell() - self.offset + + def readable(self) -> bool: + return self.base_stream.readable() + + def seekable(self) -> bool: + return self.base_stream.seekable() + + def readinto(self, b): + max_size = self.len - self.tell() + if max_size == 0: + return 0 + if len(b) > max_size: + b = memoryview(b)[:max_size] + return self.base_stream.readinto(b) # type: ignore[attr-defined] + + def read(self, size=-1): + max_size = self.len - self.tell() + if size == -1 or size > max_size: + size = max_size + return self.base_stream.read(size) + + +def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: + # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader + return _ReaderView(file, offset, length) + + +def _normalize_device_info(device_type: str, device_id: int) -> str: + """Device info normalization.""" + if device_type == "cpu": + return "cpu" + return f"{device_type}:{device_id}" + + +# TODO: integrate with distributed logging flag +ENABLE_PROFILE = False + + +@contextmanager +def _profile(): + # Only log the profiling when it is enable and is on rank0 or dist is not + # available. + if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0): + profiler = cProfile.Profile() + profiler.enable() + try: + yield + finally: + profiler.disable() + stats = Stats(profiler) + stats.sort_stats("time").print_stats(10) + else: + yield + + +def _api_bc_check(func): + @wraps(func) + def inner_func(*args, **kwargs) -> Any: + if len(args) == 2: + warnings.warn( + f"The argument order of {func.__name__} has been changed. " + "Please check the document to avoid future breakages.", + stacklevel=2, + ) + sig = inspect.signature(func) + kwonlyargs = [ + p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY + ] + if "storage_writer" in kwonlyargs: + if "storage_writer" in kwargs: + raise AssertionError(f"storage_writer in kwargs: {(args, kwargs)}") + kwargs["storage_writer"] = args[1] + elif "storage_reader" in kwonlyargs: + if "storage_reader" in kwargs: + raise AssertionError(f"storage_reader in kwargs: {(args, kwargs)}") + kwargs["storage_reader"] = args[1] + else: + raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}") + return func(args[0], **kwargs) + else: + return func(*args, **kwargs) + + return inner_func diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93295802ae847cb939954e8c8918dfd2ce49cf4f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__init__.py @@ -0,0 +1,88 @@ +import logging +import multiprocessing +import socket + +# import for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +logger: logging.Logger = logging.getLogger(__name__) + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + """ + Start the debug server stack on all workers. The frontend debug server is + only started on rank0 while the per rank worker servers are started on all + ranks. + + This server provides an HTTP frontend that allows for debugging slow and + deadlocked distributed jobs across all ranks simultaneously. This collects + data such as stack traces, FlightRecorder events, and performance profiles. + + This depends on dependencies which are not installed by default. + + Dependencies: + - Jinja2 + - aiohttp + + WARNING: This is intended to only be used in trusted network environments. + The debug server is not designed to be secure and should not be exposed to + the public internet. See SECURITY.md for more details. + + WARNING: This is an experimental feature and may change at any time. + + Args: + port (int): The port to start the frontend debug server on. + worker_port (int): The port to start the worker server on. Defaults to 0, which + will cause the worker server to bind to an ephemeral port. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + logger.info("Starting debug server on port %d", port) + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._frontend import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + """ + Shutdown the debug server and stop the frontend debug server process. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + logger.info("Stopping debug server") + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebdd99d9ed7ad9f8651c4f921ce34d6d3231f6e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95dc73c09c917cb62919768c786fddea77771e2f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_frontend.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8be441fe6e649ca94c73cebbc26c891987221243 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9acde11c9d4fa39361eb91707b6c76af1ed763fa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/__pycache__/_store.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..16cccb88632f0372bac132a01cc8b97f60223852 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_frontend.py @@ -0,0 +1,553 @@ +import asyncio +import json +import logging +import socket +import threading +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse + +from jinja2 import DictLoader, Environment +from tabulate import tabulate + +from torch.distributed.debug._store import get_world_size, tcpstore_client +from torch.distributed.flight_recorder.components.builder import build_db +from torch.distributed.flight_recorder.components.config_manager import JobConfig +from torch.distributed.flight_recorder.components.types import ( + Collective, + Group, + Membership, + NCCLCall, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class Response: + status_code: int + text: str + + def raise_for_status(self): + if self.status_code != 200: + raise RuntimeError(f"HTTP {self.status_code}: {self.text}") + + def json(self): + return json.loads(self.text) + + +def fetch_thread_pool(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import requests + + max_workers = 20 + + def get(url: str) -> Response: + resp = requests.post(url) + return Response(resp.status_code, resp.text) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + resps = executor.map(get, urls) + + return resps + + +def fetch_aiohttp(urls: list[str]) -> Iterable[Response]: + # late import for optional dependency + import aiohttp + + async def fetch(session: aiohttp.ClientSession, url: str) -> Response: + async with session.post(url) as resp: + text = await resp.text() + return Response(resp.status, text) + + async def gather(urls: list[str]) -> Iterable[Response]: + async with aiohttp.ClientSession() as session: + return await asyncio.gather(*[fetch(session, url) for url in urls]) + + return asyncio.run(gather(urls)) + + +def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + try: + resps = fetch_aiohttp(addrs) + except ImportError: + resps = fetch_thread_pool(addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+ {% block header %}{% endblock %} + {% block content %}{% endblock %} +
+ """, + "index.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}Index{% endblock %}

+{% endblock %} +{% block content %} +Hi +{% endblock %} + """, + "raw_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{title}}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "json_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ format_json(resp.text) }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "profile.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}torch.profiler{% endblock %}

+{% endblock %} + +{% block content %} +
+ + + +
+ + + + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, + "tcpstore.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}TCPStore Keys{% endblock %}

+{% endblock %} +{% block content %} +
+    {% for k, v in zip(keys, values) -%}
+{{ k }}: {{ v | truncate(100) }}
+    {% endfor %}
+    
+{% endblock %} + """, + "fr_trace.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} +

Groups

+ {{ groups | safe }} +

Memberships

+ {{ memberships | safe }} +

Collectives

+ {{ collectives | safe }} +

NCCL Calls

+ {{ ncclcalls | safe }} +{% endblock %} + """, + "pyspy_dump.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}py-spy Stack Traces{% endblock %}

+{% endblock %} +{% block content %} +
+ + + + + +
+ + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, +} + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore + request_queue_size: int = 1024 + + +class HTTPRequestHandler(BaseHTTPRequestHandler): + frontend: "FrontendServer" + + def log_message(self, format, *args): + logger.info( + "%s %s", + self.client_address[0], + format % args, + ) + + def do_GET(self): + self.frontend._handle_request(self) + + def get_path(self) -> str: + return urlparse(self.path).path + + def get_query(self) -> dict[str, list[str]]: + return parse_qs(self.get_raw_query()) + + def get_raw_query(self) -> str: + return urlparse(self.path).query + + def get_query_arg( + self, name: str, default: object = None, type: type = str + ) -> object: + query = self.get_query() + if name not in query: + return default + return type(query[name][0]) + + +class FrontendServer: + def __init__(self, port: int): + # Setup templates + loader = DictLoader(templates) + self._jinja_env = Environment(loader=loader, enable_async=True) + self._jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, + ) + + # Create routes + self._routes = { + "/": self._handle_index, + "/stacks": self._handle_stacks, + "/pyspy_dump": self._handle_pyspy_dump, + "/fr_trace": self._handle_fr_trace, + "/fr_trace_json": self._handle_fr_trace_json, + "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json, + "/profile": self._handle_profiler, + "/wait_counters": self._handle_wait_counters, + "/tcpstore": self._handle_tcpstore, + } + + # Create HTTP server + RequestHandlerClass = type( + "HTTPRequestHandler", + (HTTPRequestHandler,), + {"frontend": self}, + ) + + server_address = ("", port) + self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) + + self._thread = threading.Thread( + target=self._serve, + args=(), + daemon=True, + name="distributed.debug.FrontendServer", + ) + self._thread.start() + + def _serve(self) -> None: + try: + self._server.serve_forever() + except Exception: + logger.exception("got exception in frontend server") + + def join(self) -> None: + self._thread.join() + + def _handle_request(self, req: HTTPRequestHandler) -> None: + path = req.get_path() + if path not in self._routes: + req.send_error(404, f"Handler not found: {path}") + return + + handler = self._routes[path] + try: + resp = handler(req) + # Catch SystemExit to not crash when FlightRecorder errors. + except (Exception, SystemExit) as e: + logger.exception( + "Exception in frontend server when handling %s", + path, + ) + req.send_error(500, f"Exception: {repr(e)}") + return + + req.send_response(200) + req.send_header("Content-type", "text/html") + req.end_headers() + req.wfile.write(resp) + + def _render_template(self, template: str, **kwargs: object) -> bytes: + return self._jinja_env.get_template(template).render(**kwargs).encode() + + def _handle_index(self, req: HTTPRequestHandler) -> bytes: + return self._render_template("index.html") + + def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_traceback") + return self._render_template( + "raw_resp.html", title="Stacks", addrs=addrs, resps=resps + ) + + def _handle_pyspy_dump(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("pyspy_dump", req.get_raw_query()) + return self._render_template( + "pyspy_dump.html", + addrs=addrs, + resps=resps, + ) + + def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes: + config = JobConfig() + # pyrefly: ignore [bad-assignment] + args = config.parse_args(args=[]) + args.allow_incomplete_ranks = True + args.verbose = True + + details = {} + for rank, resp in enumerate(resps): + resp.raise_for_status() + dump = { + "rank": rank, + "host_name": addrs[rank], + **resp.json(), + } + if "entries" not in dump: + dump["entries"] = [] + details[f"rank{rank}.json"] = dump + + version = next(iter(details.values()))["version"] + + db = build_db(details, args, version) + + return self._render_template( + "fr_trace.html", + title="FlightRecorder", + groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"), + memberships=tabulate( + db.memberships, headers=Membership._fields, tablefmt="html" + ), + collectives=tabulate( + db.collectives, headers=Collective._fields, tablefmt="html" + ), + ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"), + ) + + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_fr_trace(addrs, list(resps)) + + def _handle_fr_trace_nccl_json(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: + duration = req.get_query_arg("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return self._render_template("profile.html", addrs=addrs, resps=resps) + + def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("wait_counter_values") + return self._render_template( + "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps + ) + + def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes: + store = tcpstore_client(prefix="") + keys = store.list_keys() + keys.sort() + values = [repr(v) for v in store.multi_get(keys)] + return self._render_template("tcpstore.html", keys=keys, values=values) + + +def main(port: int) -> None: + logger.setLevel(logging.INFO) + + server = FrontendServer(port=port) + logger.info("Frontend server started on port %d", server._server.server_port) + server.join() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..b8095c5b34bea5d2408ed87b21b541bf8966f4ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_handlers.py @@ -0,0 +1,23 @@ +import pathlib +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, _Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(pathlib.Path(f.name).read_bytes(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_store.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_store.py new file mode 100644 index 0000000000000000000000000000000000000000..487dd30abd6aff96d676ee3cf10d98490613e2a1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/debug/_store.py @@ -0,0 +1,25 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client(prefix: str = "debug_server") -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + if prefix: + store = dist.PrefixStore(prefix, store) + return store diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c9b29a750593a812907ce2cf4c800d7d1435bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__init__.py @@ -0,0 +1,77 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" + +Torchelastic agent and user worker failover contract: + +**TL;DR;**: + +* TE(torchelastic) expects user workers to finish with the 5 minutes drift +* It is better to design DDP app to fail for all workers, rather than a single one. +* TE does not synchronize number of restarts between agents +* TE re-rendezvous does not trigger restart decrease +* When a single agent finishes its job(successfully or not), it will close rendezvous. + If other agents still have workers in progress, they will be terminated. +* Based on above, scale down does not work if at least single agent finishes the job. +* When Scale up is detected by agents, it will not decrease ``max_restarts`` + + +In general TE(torchelastic) can launch arbitrary user code, but there is some +clarifications need to be done around what failover mechanism torchelastic +provides and what failover mechanism it expects from user workers. + +Torchelastic currently supports DDP style applications. That means that +TE expects *ALL* workers finish approximately at the same time. In practice, +it is nearly to impossible to guarantee that all workers in arbitrary +DDP application finish at the time, so TE provides a finalization barrier +that waits for TIMEOUT(5 minutes) for worker finalization. + +**Worker Failure** + +When worker fails, TE will check the number of restarts +available, if there is more than 0 restarts, TE will start a new rendezvous +round and restart the worker process. New rendezvous round will other +TE agents to terminate their workers. + +.. note:: The TE agent does not synchronize restarts between themselves. + When a single agent performs restart, it will trigger a local ``max_restarts`` + decrease, other agent will not decrease their ``max_restarts``. + the user to run the distributed application locally on a dev host. + +A single worker failure can cause the whole cluster to fail: +If a single worker is constantly failing, it will cause the TE agent +``max_restarts`` to go to zero. This will cause an agent to finish its +work and close rendezvous. If there are any other workers on different +agents, they will be terminated. + + +**Re-Rendezvous** + +Re-rendezvous occurs when TE agents detect a new node +trying to joint a cluster. TE will not decrease ``max_restarts``. TE agents +will terminate its workers and start a new rendezvous round. + +Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous +has already max_nodes, the new node won't be added to the wait list right +away since there is no need to tear down a rendezvous that is already fully +utilized. The new node will wait until its timeout (600 secs by default) +and periodically check the number of participants. If the number becomes +less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs. + +*Scale up event*. When scale up event happens, torchelastic rendezvous +will detect that there are new nodes trying to join. Torchelastic agent +will stop all workers and perform re-rendezvous. Note: when scale up event +happens, *``max_restarts``* will *not* decrease. + +*Scale down event*. When scale down event happens, rendezvous will not +notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , +it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , +TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. + +""" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b044819b76ace30649f2a01dfb19080ce79b3764 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4934f32634eea7039879cb6d4eb9cd127098b0a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4737dcf3a77ca2bba8985248eb5dfc0c41d070e6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0d76131fe40d70945ffa8ff97431954151d50e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +The elastic agent is the control plane of torchelastic. + +It is a process that launches and manages underlying worker processes. +The agent is responsible for: + +1. Working with distributed torch: the workers are started with all the + necessary information to successfully and trivially call + ``torch.distributed.init_process_group()``. + +2. Fault tolerance: monitors workers and upon detecting worker failures + or unhealthiness, tears down all workers and restarts everyone. + +3. Elasticity: Reacts to membership changes and restarts workers with the new + members. + +The simplest agents are deployed per node and works with local processes. +A more advanced agent can launch and manage workers remotely. Agents can +be completely decentralized, making decisions based on the workers it manages. +Or can be coordinated, communicating to other agents (that manage workers +in the same job) to make a collective decision. +""" + +from .api import ( # noqa: F401 + ElasticAgent, + RunResult, + SimpleElasticAgent, + Worker, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a649be9e5d1b6191dd17668a345f70d36fdcd65 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b5b0cdfe4a9d1dfc9b28581aaebf9b925c1f9d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99acebd8aee5ffc8e9986c09aaaa23cd1385706b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9639631a61b1d49d224c9b83bc57c8a0723632a9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2575aa137a58128213173dde681c313bb24fc5a2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py @@ -0,0 +1,995 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import json +import os +import signal +import socket +import time +import traceback +import warnings +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +import torch.distributed.elastic.rendezvous as rdzv +import torch.distributed.elastic.utils.store as store_util +from torch.distributed.elastic.events import Event, EventSource, record +from torch.distributed.elastic.metrics import prof, put_metric +from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = [ + "WorkerSpec", + "Worker", + "WorkerState", + "WorkerGroup", + "RunResult", + "ElasticAgent", + "SimpleElasticAgent", +] +_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" + +DEFAULT_ROLE = "default" +logger = get_logger(__name__) + + +@dataclass +class WorkerSpec: + """ + Blueprint information about a particular type of worker. + + For a given role, there must only exist a single worker spec. + Worker spec is expected to be homogeneous across all nodes (machine), + that is each node runs the same number of workers for a particular spec. + + Args: + role: user-defined role for the workers with this spec + local_world_size: number local workers to run + fn: (deprecated use entrypoint instead) + entrypoint: worker function or command + args: arguments to pass to ``entrypoint`` + rdzv_handler: handles rdzv for this set of workers + max_restarts: number of max retries for the workers + monitor_interval: monitor status of workers every ``n`` seconds + master_port: fixed port to run the c10d store on rank 0 + if not specified then will chose a random free port + master_addr: fixed master_addr to run the c10d store on rank 0 + if not specified then will chose hostname on agent rank 0 + redirects: redirect std streams to a file, + selectively redirect for a particular + local rank by passing a map + tee: tees the specified std stream(s) to console + file, + selectively tee for a particular local rank by passing a map, + takes precedence over ``redirects`` settings. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines + that match _any_ of the filter strings. + duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines + that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). + When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. + """ + + role: str + local_world_size: int + rdzv_handler: rdzv.RendezvousHandler + fn: Callable | None = None + # TODO @kiuk - make entrypoint a required field + entrypoint: Callable | str | None = None + args: tuple = () + max_restarts: int = 3 + monitor_interval: float = 0.1 + master_port: int | None = None + master_addr: str | None = None + local_addr: str | None = None + event_log_handler: str = "null" + numa_options: NumaOptions | None = None + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None + virtual_local_rank: bool = False + + def __post_init__(self): + assert self.local_world_size > 0 + assert self.monitor_interval > 0 + + if self.fn: + warnings.warn( + "WorkerSpec.fn will be deprecated," + " please use WorkerSpec.entrypoint instead", + stacklevel=2, + category=DeprecationWarning, + ) + self.entrypoint = self.fn + assert self.entrypoint + + def get_entrypoint_name(self): + """Get the entry point name. + + If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` + else if the entrypoint is a binary (e.g. ``str``), returns the binary name. + """ + if isinstance(self.entrypoint, str): + return os.path.basename(self.entrypoint) + else: + assert self.entrypoint is not None + return self.entrypoint.__qualname__ + + +class Worker: + """A worker instance. + + Contrast this with ``WorkerSpec`` that represents the specifications of a + worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to + a ``WorkerSpec`` as an object is to a class. + + The ``id`` of the worker is interpreted + by the specific implementation of ``ElasticAgent``. For a local + agent, it could be the ``pid (int)`` of the worker, for a remote + agent it could be encoded as ``host:port (string)``. + + Args: + id (Any): uniquely identifies a worker (interpreted by the agent) + local_rank (int): local rank of the worker + global_rank (int): global rank of the worker + role_rank (int): rank of the worker across all workers that have the same role + world_size (int): number of workers (globally) + role_world_size (int): number of workers that have the same role + """ + + __slots__ = [ + "id", + "local_rank", + "global_rank", + "role_rank", + "world_size", + "role_world_size", + ] + + def __init__( + self, + local_rank: int, + global_rank: int = -1, + role_rank: int = -1, + world_size: int = -1, + role_world_size: int = -1, + ): + # unique identifier for this worker + self.id: Any = None + + # rank of the worker among workers with the same role being monitored + # by the same ``agent`` instance. + self.local_rank: int = local_rank + + # rank of the worker among all the workers across all roles + # across all ``agent`` instances. + # Global rank is not stable between re-rendezvous. + self.global_rank: int = global_rank + + # rank of the worker among all the workers with the same role + # across all ``agent`` instances. + # Role rank is not stable between re-rendezvous. + self.role_rank: int = role_rank + + # total number of workers (globally). Due to elasticity + # the world size may change between re-rendezvous. + self.world_size: int = world_size + + # total number of workers that share the same role. Due to elasticity + # the role world size may change between re-rendezvous. + self.role_world_size: int = role_world_size + + def __str__(self): + return ( + f"local_rank={self.local_rank},global_rank={self.global_rank}" + f",role_rank={self.role_rank},world_size={self.world_size}" + f",role_world_size={self.role_world_size}" + ) + + def __repr__(self): + return str(self) + + +class WorkerState(str, Enum): + """A state of the ``WorkerGroup``. + + Workers in a worker group change state as a unit. If a single worker + in a worker group fails the entire set is considered failed:: + + UNKNOWN - agent lost track of worker group state, unrecoverable + INIT - worker group object created not yet started + HEALTHY - workers running and healthy + UNHEALTHY - workers running and unhealthy + STOPPED - workers stopped (interrupted) by the agent + SUCCEEDED - workers finished running (exit 0) + FAILED - workers failed to successfully finish (exit !0) + + + A worker group starts from an initial ``INIT`` state, + then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, + and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. + + Worker groups can be interrupted and temporarily put into ``STOPPED`` state + by the agent. Workers in ``STOPPED`` state are scheduled to be restarted + in the near future by the agent. Some examples of workers being put into + ``STOPPED`` state are: + + 1. Worker group failure|unhealthy observed + 2. Membership change detected + + When actions (start, stop, rdzv, retry, etc) on worker group fails + and results in the action being partially applied to the worker group + the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled + exceptions during state change events on the agent. The agent is not + expected to recover worker groups in ``UNKNOWN`` state and is better off + self terminating and allowing the job manager to retry the node. + """ + + UNKNOWN = "UNKNOWN" + INIT = "INIT" + HEALTHY = "HEALTHY" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + @staticmethod + def is_running(state: "WorkerState") -> bool: + """Return the state of the Worker. + + Returns: + True if the worker state represents workers still running + (e.g. that the process exists but not necessarily healthy). + """ + return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} + + +class WorkerGroup: + """A set of ``Worker`` instances. + + The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker + group contains cross instance workers or not depends on the implementation of the agent. + """ + + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] + + def __init__(self, spec: WorkerSpec): + self.spec = spec + self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] + + # assigned after rdzv + self.store = None + self.group_rank = None + self.group_world_size = None + self.master_addr = None + self.master_port = None + + self.state = WorkerState.INIT + + +class _RoleInstanceInfo: + """The class is used by the agent to exchange the information with other agents. + + The information is used to determine the rank of the workers that agent + manages in heterogeneous environments, where different agents can have + different number of workers. + """ + + __slots__ = ["role", "rank", "local_world_size"] + + def __init__(self, role: str, rank: int, local_world_size: int): + r"""Initialize the agent class instance. + + Args: + role (str): user-defined role for the workers with this spec + rank (int): the rank of the agent + local_world_size (int): number of local workers to run + """ + self.role = role + self.rank = rank + self.local_world_size = local_world_size + + def serialize(self) -> bytes: + dict_data = { + "role": self.role, + "rank": self.rank, + "local_world_size": self.local_world_size, + } + return json.dumps(dict_data).encode(encoding="UTF-8") + + @staticmethod + def deserialize(data: bytes): + dict_data = json.loads(data.decode(encoding="UTF-8")) + return _RoleInstanceInfo( + dict_data["role"], dict_data["rank"], dict_data["local_world_size"] + ) + + @staticmethod + def compare(obj1, obj2) -> int: + if obj1.role == obj2.role: + return obj1.rank - obj2.rank + elif obj1.role > obj2.role: + return 1 + else: + return -1 + + @staticmethod + def find_role_boundaries(roles_infos: list, role: str) -> tuple[int, int]: + start_idx, end_idx = -1, -1 + for idx, role_info in enumerate(roles_infos): + if role_info.role == role: + if start_idx == -1: + start_idx = idx + end_idx = idx + return (start_idx, end_idx) + + +@dataclass +class RunResult: + """Return results of the worker executions. + + Run results follow an "all-or-nothing" policy where the run is successful if and + only if ALL local workers managed by this agent complete successfully. + + If the result is successful (e.g. ``is_failed() = False``) then the ``return_values`` + field contains the outputs (return values) of the workers managed by THIS agent mapped + by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of + global rank 0. + + .. note:: ``return_values`` are only meaningful for when the worker entrypoint + is a function. Workers specified as a binary entrypoint do not canonically + have a return value and the ``return_values`` field is meaningless and + may be empty. + + If ``is_failed()`` returns ``True`` then the ``failures`` field contains the + failure information, again, mapped by the GLOBAL rank of the worker that failed. + + The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, + a worker's final state can only be one of: succeeded, failed. Workers intentionally + terminated by the agent according to the agent's restart policy, are not represented + in either ``return_values`` nor ``failures``. + """ + + state: WorkerState + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + + def is_failed(self) -> bool: + return self.state == WorkerState.FAILED + + +def _get_fq_hostname() -> str: + return socket.getfqdn(socket.gethostname()) + + +class ElasticAgent(abc.ABC): + """An agent process responsible for managing one or more worker processes. + + The worker processes are assumed to be regular distributed PyTorch scripts. + When the worker process is created by the agent, the agent provides the + necessary information for the worker processes to properly initialize + a torch process group. + + The exact deployment topology and ratio of agent-to-worker is dependent + on the specific implementation of the agent and the user's job placement + preferences. For instance, to run a distributed training job on GPU with + 8 trainers (one per GPU) one can: + + 1. Use 8 x single GPU instances, place an agent per instance, managing + 1 worker per agent. + 2. Use 4 x double GPU instances, place an agent per instance, managing + 2 workers per agent. + 3. Use 2 x quad GPU instances, place an agent per instance, managing + 4 workers per agent. + 4. Use 1 x 8 GPU instance, place an agent per instance, managing + 8 workers per agent. + + Usage + :: + + group_result = agent.run() + if group_result.is_failed(): + # workers failed + failure = group_result.failures[0] + logger.exception("worker 0 failed with exit code : %s", failure.exit_code) + else: + return group_result.return_values[0] # return rank 0's results + + """ + + @abc.abstractmethod + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + """Run the agent. + + Supports retrying the worker group on failures up to ``max_restarts``. + + Returns: + The result of the execution, containing the return values or + failure details for each worker mapped by the worker's global rank. + + Raises: + Exception - any other failures NOT related to worker process + """ + raise NotImplementedError + + @abc.abstractmethod + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + """Return the ``WorkerGroup`` for the given ``role``. + + Note that the worker group is a mutable object and hence in a + multi-threaded/process environment it may change state. + Implementers are encouraged (but not required) to return + a defensive read-only copy. + """ + raise NotImplementedError + + +class SimpleElasticAgent(ElasticAgent): + """An ``ElasticAgent`` that manages one particular type of worker role. + + An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` + such as one particular type of worker role. + """ + + def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): + self._worker_group = WorkerGroup(spec) + self._remaining_restarts = self._worker_group.spec.max_restarts + self._store = None + self._exit_barrier_timeout = exit_barrier_timeout + self._total_execution_time = 0 + + def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: + return self._worker_group + + @abc.abstractmethod + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + r"""Start ``worker_group.spec.local_world_size`` number of workers. + + This is according to worker spec for the worker group . + Returns a map of ``local_rank`` to worker ``id``. + """ + raise NotImplementedError + + @abc.abstractmethod + def _stop_workers(self, worker_group: WorkerGroup) -> None: + r"""Stop all workers in the given worker group. + + Implementers must deal with workers in all states defined by + ``WorkerState``. That is, it must gracefully handle stopping + non-existent workers, unhealthy (stuck) workers, etc. + """ + raise NotImplementedError + + @abc.abstractmethod + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + r"""Check on the workers for the ``worker_group``. + + This function also returns the new state of the worker group. + """ + raise NotImplementedError + + @abc.abstractmethod + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + """Clean up any resources that were allocated during the agent's work. + + Args: + death_sig: Signal to send to the child process, SIGTERM is default + """ + raise NotImplementedError + + @prof + def _rendezvous(self, worker_group: WorkerGroup) -> None: + r"""Run rendezvous for the workers specified by the worker spec. + + Assigns workers a new global rank and world size. + Updates the rendezvous store for the worker group. + """ + spec = worker_group.spec + + with self.record_duration("RENDEZVOUS"): + rdzv_info = spec.rdzv_handler.next_rendezvous() + store = rdzv_info.store + group_rank = rdzv_info.rank + group_world_size = rdzv_info.world_size + + # master_addr/master_port could be explicitly overridden + # TODO: BC - specific to static rdzv and can be simplified further + master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr + master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port + + self._store = store + + with self.record_duration("ASSIGN_WORKER_RANKS"): + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) + worker_group.workers = workers + worker_group.store = store + worker_group.group_rank = group_rank + worker_group.group_world_size = group_world_size + worker_group.master_addr = master_addr + worker_group.master_port = master_port + + restart_count = spec.max_restarts - self._remaining_restarts + + logger.info( + "[%(role)s] Rendezvous complete for workers. Result:\n" + " restart_count=%(restart_count)s\n" + " master_addr=%(master_addr)s\n" + " master_port=%(master_port)s\n" + " group_rank=%(group_rank)s\n" + " group_world_size=%(group_world_size)s\n" + " local_ranks=%(local_ranks)s\n" + " role_ranks=%(role_ranks)s\n" + " global_ranks=%(global_ranks)s\n" + " role_world_sizes=%(role_world_sizes)s\n" + " global_world_sizes=%(global_world_sizes)s\n" + " event_log_handler=%(event_log_handler)s\n", + { + "role": spec.role, + "restart_count": restart_count, + "master_addr": master_addr, + "master_port": master_port, + "group_rank": group_rank, + "group_world_size": group_world_size, + "local_ranks": [worker.local_rank for worker in workers], + "role_ranks": [worker.role_rank for worker in workers], + "global_ranks": [worker.global_rank for worker in workers], + "role_world_sizes": [worker.role_world_size for worker in workers], + "global_world_sizes": [worker.world_size for worker in workers], + "event_log_handler": spec.event_log_handler, + }, + ) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _assign_worker_ranks( + self, store, group_rank: int, group_world_size: int, spec: WorkerSpec + ) -> list[Worker]: + """Determine proper ranks for worker processes. + + Fast Path: when all workers have the same role and world size. We calculate + the global rank to be group_rank * group_world_size + local_rank. And the + `role_world_size` is the same as `global_world_size`. No TCP store is used in + this case. This is only enabled when users set the environment variable + `TORCH_ELASTIC_WORKER_IDENTICAL` to 1. + + Time complexity: each worker O(1), overall O(1) + + Slow Path: when workers have different roles and world sizes. We use the + the following algorithm: + + 1. Each agent writes its configuration(group_rank, group_world_size + , num_workers) to the common store. + 2. The rank 0 agent reads all the role_info from the store and + determines each agents worker ranks. + 3. Determine the global rank: the global rank of the workers is computed + by cumulative sum of the local_world_size for all workers in front of it. + For efficiency reasons each worker is assigned a base global rank + such that it's workers are in the range [base_global_rank, + base_global_rank + local_world_size). + 4. Determine the role rank: The role rank is determined using the algorithms + in the point 3 with the exception that the ranks are calculated with + respect to the role name. + 5. The rank 0 agent writes the assigned ranks to the store. + 6. Each agent reads the assigned ranks from the store. + + Time complexity: each worker O(1), rank0 O(n), overall O(n) + """ + + if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1": + global_world_size = group_world_size * spec.local_world_size + base_global_rank = group_rank * spec.local_world_size + base_role_rank = base_global_rank + role_world_size = global_world_size + else: + ROLE_INFO_PREFIX = "torchelastic/role_info/" + ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" + + agent_role_info = _RoleInstanceInfo( + spec.role, group_rank, spec.local_world_size + ) + store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) + + # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. + if group_rank == 0: + role_infos_bytes = store.multi_get( + [f"torchelastic/role_info/{i}" for i in range(group_world_size)] + ) + role_infos = [ + _RoleInstanceInfo.deserialize(info_bytes) + for info_bytes in role_infos_bytes + ] + + role_sizes = defaultdict(lambda: 0) + global_size = 0 + for role_info in role_infos: + role_sizes[role_info.role] += role_info.local_world_size + global_size += role_info.local_world_size + + base_global_rank = 0 + role_ranks = defaultdict(lambda: 0) + + keys = [] + values = [] + for i, role_info in enumerate(role_infos): + keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") + values.append( + json.dumps( + [ + base_global_rank, + global_size, + role_ranks[role_info.role], + role_sizes[role_info.role], + ] + ) + ) + + base_global_rank += role_info.local_world_size + role_ranks[role_info.role] += role_info.local_world_size + + store.multi_set(keys, values) + + # get will block until the data is available in the store. + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) + + workers = [] + for local_rank in range(spec.local_world_size): + worker = Worker( + local_rank=local_rank, + global_rank=base_global_rank + local_rank, + role_rank=base_role_rank + local_rank, + world_size=global_world_size, + role_world_size=role_world_size, + ) + workers.append(worker) + return workers + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _initialize_workers(self, worker_group: WorkerGroup) -> None: + r"""Start a fresh set of workers for the worker_group. + + Essentially, a rendezvous followed by a ``start_workers``. + The caller should first call ``_stop_workers()`` to stop running workers + prior to calling this method. + + Optimistically sets the state of the worker group that + just started as ``HEALTHY`` and delegates the actual monitoring + of state to ``_monitor_workers()`` method + """ + role = worker_group.spec.role + logger.info("[%s] Rendezvous'ing worker group", role) + + # TODO after stopping workers, wait at least monitor_interval*2 for + # workers on different nodes to fail on a collective op before waiting + # on the rdzv barrier, this way we ensure that nodes enter rdzv + # at around the same time and reduce false positive rdzv timeout errors + self._rendezvous(worker_group) + + logger.info("[%s] Starting worker group", role) + worker_ids = self._start_workers(worker_group) + for local_rank, w_id in worker_ids.items(): + worker = worker_group.workers[local_rank] + worker.id = w_id + record( + self._construct_event("START", EventSource.WORKER, worker), + worker_group.spec.event_log_handler, + ) + + worker_group.state = WorkerState.HEALTHY + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _restart_workers(self, worker_group: WorkerGroup) -> None: + """Restart (stops, rendezvous, starts) all local workers in the group.""" + role = worker_group.spec.role + logger.info("[%s] Stopping worker group", role) + self._stop_workers(worker_group) + worker_group.state = WorkerState.STOPPED + self._initialize_workers(worker_group) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def run(self, role: str = DEFAULT_ROLE) -> RunResult: + start_time = time.monotonic() + shutdown_called: bool = False + try: + result = self._invoke_run(role) + self._total_execution_time = int(time.monotonic() - start_time) + self._record_metrics(result) + self._record_worker_events(result) + return result + except RendezvousGracefulExitError as e: + logger.info("Rendezvous gracefully exited: %s", e) # noqa: G200 + except SignalException as e: + logger.warning("Received %s death signal, shutting down workers", e.sigval) + self._shutdown(e.sigval) + shutdown_called = True + raise + finally: + if not shutdown_called: + self._shutdown() + # record the execution time in case there were any exceptions during run. + self._total_execution_time = int(time.monotonic() - start_time) + + def get_event_failed(self) -> Event: + return self._construct_event( + state="FAILED", + source=EventSource.AGENT, + raw_error=traceback.format_exc(), + ) + + def get_event_succeeded(self) -> Event: + return self._construct_event( + state="SUCCEEDED", + source=EventSource.AGENT, + ) + + def _record_worker_events(self, result: RunResult) -> None: + for worker in self._worker_group.workers: + failure = result.failures.get(worker.global_rank) + state: str = self._get_worker_state(worker, result) + raw_error = json.dumps(failure.error_file_data) if failure else None + exit_code = failure.exitcode if failure else None + worker_pid = failure.pid if failure else None + record( + self._construct_event( + state=state, + source=EventSource.WORKER, + worker=worker, + raw_error=raw_error, + exit_code=exit_code, + worker_pid=worker_pid, + ), + self._worker_group.spec.event_log_handler, + ) + + def _get_worker_state(self, worker: Worker, result: RunResult) -> str: + failure = result.failures.get(worker.global_rank) + if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: + # The worker got terminated by the torchelastic agent via SIGTERM signal + return "TERMINATED" + elif failure or worker.global_rank in result.return_values: + return result.state.value + else: + raise ValueError(f"Unknown worker: {worker.global_rank}") + + @contextmanager + def record_duration(self, state: str): + start_time = time.perf_counter() + try: + yield + finally: + end_time = time.perf_counter() + duration_ms = (end_time - start_time) * 1000 + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ), + self._worker_group.spec.event_log_handler, + ) + + def _construct_event( + self, + state: str, + source: EventSource, + worker: Worker | None = None, + raw_error: str | None = None, + duration_ms: float | None = None, + exit_code: int | None = None, + worker_pid: int | None = None, + ) -> Event: + wg = self._worker_group + spec = wg.spec + md = { + "group_world_size": wg.group_world_size, + "entry_point": spec.get_entrypoint_name(), + } + if worker: + md["local_rank"] = (worker.local_rank,) + md["role_rank"] = (worker.role_rank,) + md["role_world_size"] = (worker.role_world_size,) + md["exit_code"] = (exit_code,) + md["worker_pid"] = (worker_pid,) + global_rank = worker.global_rank + worker_id = str(worker.id) + else: + global_rank = None + worker_id = None + md_str = json.dumps(md) + metadata = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": global_rank, + "group_rank": wg.group_rank, + "worker_id": worker_id, + "role": spec.role, + "hostname": _get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": raw_error, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + "duration_ms": duration_ms, + } + + return Event( + f"torchelastic.worker.status.{state}", source=source, metadata=metadata + ) + + def _record_metrics(self, group_results: RunResult): + is_failed = group_results.is_failed() + self._record_flakiness_metric(is_failed) + spec = self._worker_group.spec + restarts_happened = self._remaining_restarts != spec.max_restarts + put_metric(f"workers.{spec.role}.run_total", 1) + self._record_metric_with_condition( + "run_success_with_retries", not is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_success_no_retries", not is_failed and not restarts_happened + ) + self._record_metric_with_condition( + "run_failed_with_retries", is_failed and restarts_happened + ) + self._record_metric_with_condition( + "run_failed_no_retries", is_failed and not restarts_happened + ) + + def _record_metric_with_condition(self, metric_name, condition): + spec = self._worker_group.spec + if condition: + put_metric(f"workers.{spec.role}.{metric_name}", 1) + else: + put_metric(f"workers.{spec.role}.{metric_name}", 0) + + def _record_flakiness_metric(self, is_failed: bool = False): + if is_failed: + flakiness = 100.0 + else: + spec = self._worker_group.spec + flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( + spec.max_restarts + 1 + ) + spec = self._worker_group.spec + + put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) + + def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: + # NOTE: currently only works for a single role + + spec = self._worker_group.spec + role = spec.role + + logger.info( + "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() + ) + + self._initialize_workers(self._worker_group) + monitor_interval = spec.monitor_interval + rdzv_handler = spec.rdzv_handler + + while True: + assert self._worker_group.state != WorkerState.INIT + time.sleep(monitor_interval) + run_result = self._monitor_workers(self._worker_group) + state = run_result.state + self._worker_group.state = state + + put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) + put_metric(f"workers.{role}.{state.name.lower()}", 1) + + if state == WorkerState.SUCCEEDED: + logger.info( + "[%s] worker group successfully finished." + " Waiting %s seconds for other agents to finish.", + role, + self._exit_barrier_timeout, + ) + self._exit_barrier() + return run_result + elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: + if self._remaining_restarts > 0: + logger.info( + "[%s] Worker group %s. " + "%s/%s attempts left;" + " will restart worker group", + role, + state.name, + self._remaining_restarts, + spec.max_restarts, + ) + self._remaining_restarts -= 1 + self._restart_workers(self._worker_group) + else: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return run_result + elif state == WorkerState.HEALTHY: + # membership changes do not count as retries + num_nodes_waiting = rdzv_handler.num_nodes_waiting() + group_rank = self._worker_group.group_rank + if num_nodes_waiting > 0: + logger.info( + "[%s] Detected %s " + "new nodes from group_rank=%s; " + "will restart worker group", + role, + num_nodes_waiting, + group_rank, + ) + self._restart_workers(self._worker_group) + else: + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) + + def _exit_barrier(self): + """ + Define a barrier that keeps the agent process alive until all workers finish. + + Wait for ``exit_barrier_timeout`` seconds for all agents to finish + executing their local workers (either successfully or not). This + acts as a safety guard against user scripts that terminate at different + times. + """ + logger.info( + "Local worker group finished (%s). " + "Waiting %s seconds for other agents to finish", + self._worker_group.state, + self._exit_barrier_timeout, + ) + start = time.time() + try: + store_util.barrier( + store=self._store, + world_size=self._worker_group.group_world_size, + key_prefix=_TERMINAL_STATE_SYNC_ID, + barrier_timeout=self._exit_barrier_timeout, + ) + logger.info( + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, + ) + except SignalException as e: + logger.warning("Got termination signal: %s", e.sigval) + raise + except Exception: + logger.exception( + "Error waiting on exit barrier. Elapsed: %s seconds", + time.time() - start, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py new file mode 100644 index 0000000000000000000000000000000000000000..4815d86aa289c531a01bfcc8277b7ae9ffb2930e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/health_check_server.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable + +from torch.distributed.elastic.utils.logging import get_logger + + +log = get_logger(__name__) + +__all__ = ["HealthCheckServer", "create_healthcheck_server"] + + +class HealthCheckServer: + """ + Interface for health check monitoring server, which can be extended + by starting tcp/http server on the specified port. + + Args: + + alive_callback: Callable[[], int], callback to last progress time of agent + + port: int, port number to start tcp/http server + + timeout: int, timeout seconds to decide agent is alive/dead + """ + + _alive_callback: Callable[[], int] + _port: int + _timeout: int + + def __init__( + self, alive_callback: Callable[[], int], port: int, timeout: int + ) -> None: + self._alive_callback = alive_callback + self._port = port + self._timeout = timeout + + def start(self) -> None: + """ + Unsupported functionality for Pytorch, doesn't start any health check server + """ + log.warning("No health check server started") + + def stop(self) -> None: + """ + Function to stop health check server + """ + log.info("Stopping noop health check server.") + + +def create_healthcheck_server( + alive_callback: Callable[[], int], + port: int, + timeout: int, +) -> HealthCheckServer: + """ + creates health check server object + """ + return HealthCheckServer(alive_callback, port, timeout) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ef281b6c58c318a06e2c97832ab43171313e56df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import json +import os +import signal +import socket +import time +import uuid +from string import Template +from typing import Any, TYPE_CHECKING + +import torch.distributed.elastic.timer as timer +from torch.distributed.elastic import events +from torch.distributed.elastic.agent.server.api import ( + RunResult, + SimpleElasticAgent, + WorkerGroup, + WorkerSpec, + WorkerState, +) +from torch.distributed.elastic.agent.server.health_check_server import ( + create_healthcheck_server, + HealthCheckServer, +) +from torch.distributed.elastic.metrics.api import prof +from torch.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) +from torch.distributed.elastic.utils import macros +from torch.distributed.elastic.utils.logging import get_logger + + +if TYPE_CHECKING: + from torch.distributed.elastic.events.api import EventMetadataValue + +logger = get_logger(__name__) + +__all__ = [ + "LocalElasticAgent", + "TORCHELASTIC_ENABLE_FILE_TIMER", + "TORCHELASTIC_TIMER_FILE", + "TORCHELASTIC_HEALTH_CHECK_PORT", +] + +TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER" +TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" +TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + + +class LocalElasticAgent(SimpleElasticAgent): + """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. + + This agent is deployed per host and is configured to spawn ``n`` workers. + When using GPUs, ``n`` maps to the number of GPUs available on the host. + + The local agent does not communicate to other local agents deployed on + other hosts, even if the workers may communicate inter-host. The worker id + is interpreted to be a local process. The agent starts and stops all worker + processes as a single unit. + + + The worker function and argument passed to the worker function must be + python multiprocessing compatible. To pass multiprocessing data structures + to the workers you may create the data structure in the same multiprocessing + context as the specified ``start_method`` and pass it as a function argument. + + The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait + for other agents to finish. This acts as a safety net to handle cases where + workers finish at different times, to prevent agents from viewing workers + that finished early as a scale-down event. It is strongly advised that the + user code deal with ensuring that workers are terminated in a synchronous + manner rather than relying on the exit_barrier_timeout. + + A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an + environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has + been defined in the ```LocalElasticAgent``` process. + Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` + can be set with a unique file name for the named pipe. If the environment + variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` + will internally create a unique file name and set it to the environment + variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will + be propagated to the worker processes to allow them to connect to the same + named pipe that ```LocalElasticAgent``` uses. + + Logs are written to the specified log directory. Each log line will be by default + prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). + Log prefixes can be customized by passing a `template string + `_ as the + ``log_line_prefix_template`` argument. + The following macros (identifiers) are substituted at runtime: + ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with + global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. + + + Example launching function + + :: + + def trainer(args) -> str: + return "do train" + + def main(): + start_method="spawn" + shared_queue= multiprocessing.get_context(start_method).Queue() + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint=trainer, + args=("foobar",), + ...) + agent = LocalElasticAgent(spec, start_method) + results = agent.run() + + if results.is_failed(): + print("trainer failed") + else: + print(f"rank 0 return value: {results.return_values[0]}") + # prints -> rank 0 return value: do train + + Example launching binary + + :: + + def main(): + spec = WorkerSpec( + role="trainer", + local_world_size=nproc_per_process, + entrypoint="/usr/local/bin/trainer", + args=("--trainer-args", "foobar"), + ...) + agent = LocalElasticAgent(spec) + results = agent.run() + + if not results.is_failed(): + print("binary launches do not have return values") + + """ + + def __init__( + self, + spec: WorkerSpec, + logs_specs: LogsSpecs, + start_method="spawn", + exit_barrier_timeout: float = 300, + log_line_prefix_template: str | None = None, + ): + super().__init__(spec, exit_barrier_timeout) + self._start_method = start_method + self._pcontext: PContext | None = None + self._rdzv_handler = spec.rdzv_handler + self._log_line_prefix_template = log_line_prefix_template + self._worker_watchdog: timer.FileTimerServer | None = None + self._logs_specs = logs_specs + self._health_check_server: HealthCheckServer | None = None + + def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None: + enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER + watchdog_enabled = os.getenv(enable_watchdog_env_name) + watchdog_file_env_name = TORCHELASTIC_TIMER_FILE + watchdog_file_path = os.getenv(watchdog_file_env_name) + if watchdog_enabled is not None and str(watchdog_enabled) == "1": + if watchdog_file_path is None: + watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) + logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) + if not envs: + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" + else: + run_id = envs[0]["TORCHELASTIC_RUN_ID"] + self._worker_watchdog = timer.FileTimerServer( + file_path=watchdog_file_path, + run_id=run_id, + max_interval=0.1, + daemon=True, + log_event=self._log_watchdog_event, + ) + self._worker_watchdog.start() + logger.info("FileTimerServer started") + else: + logger.info( + "Environment variable '%s' not found. Do not start FileTimerServer.", + enable_watchdog_env_name, + ) + # Propagate the watchdog file env to worker processes + if watchdog_file_path is not None: + for worker_env in envs.values(): + worker_env[watchdog_file_env_name] = watchdog_file_path + + @staticmethod + def _get_current_time_secs() -> int: + return int(time.time()) + + def _setup_healthcheck(self) -> None: + healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT + healthcheck_port = os.getenv(healthcheck_port_env_name) + if healthcheck_port is not None: + logger.info( + "Found healthcheck port %s: %s", + healthcheck_port_env_name, + healthcheck_port, + ) + if self._worker_watchdog is None: + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) + alive_callback = LocalElasticAgent._get_current_time_secs + else: + alive_callback = self._worker_watchdog.get_last_progress_time + + try: + healthcheck_port_as_int = int(healthcheck_port) + self._health_check_server = create_healthcheck_server( + alive_callback=alive_callback, + port=healthcheck_port_as_int, + timeout=60, + ) + self._health_check_server.start() + except ValueError: + logger.info( + "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", + healthcheck_port, + ) + else: + logger.info( + "Environment variable '%s' not found. Do not start health check.", + healthcheck_port_env_name, + ) + + def _get_fq_hostname(self) -> str: + return socket.getfqdn(socket.gethostname()) + + def _log_watchdog_event( + self, + name: str, + request: timer.FileTimerRequest | None, + ) -> None: + wg = self._worker_group + spec = wg.spec + md = {"watchdog_event": name} + if request is not None: + md["worker_pid"] = str(request.worker_pid) + md["scope_id"] = request.scope_id + md["expiration_time"] = str(request.expiration_time) + md["signal"] = str(request.signal) + md_str = json.dumps(md) + state = "RUNNING" + metadata: dict[str, EventMetadataValue] = { + "run_id": spec.rdzv_handler.get_run_id(), + "global_rank": None, + "group_rank": wg.group_rank, + "worker_id": None, + "role": spec.role, + "hostname": self._get_fq_hostname(), + "state": state, + "total_run_time": self._total_execution_time, + "rdzv_backend": spec.rdzv_handler.get_backend(), + "raw_error": None, + "metadata": md_str, + "agent_restarts": spec.max_restarts - self._remaining_restarts, + } + # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. + # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. + event = events.Event( + name=name, source=events.EventSource.AGENT, metadata=metadata + ) + events.record(event, self._worker_group.spec.event_log_handler) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _stop_workers(self, worker_group: WorkerGroup) -> None: + self._shutdown() + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: + spec = worker_group.spec + store = worker_group.store + assert store is not None + restart_count = spec.max_restarts - self._remaining_restarts + + use_agent_store: bool = spec.rdzv_handler.use_agent_store + logger.info("use_agent_store: %s", use_agent_store) + + args: dict[int, tuple] = {} + envs: dict[int, dict[str, str]] = {} + log_line_prefixes: dict[int, str] | None = ( + {} if self._log_line_prefix_template else None + ) + for worker in worker_group.workers: + local_rank = worker.local_rank + worker_env = { + "RANK": str(worker.global_rank), + "GROUP_RANK": str(worker_group.group_rank), + "ROLE_RANK": str(worker.role_rank), + "ROLE_NAME": spec.role, + "LOCAL_WORLD_SIZE": str(spec.local_world_size), + "WORLD_SIZE": str(worker.world_size), + "GROUP_WORLD_SIZE": str(worker_group.group_world_size), + "ROLE_WORLD_SIZE": str(worker.role_world_size), + "MASTER_ADDR": worker_group.master_addr, + "MASTER_PORT": str(worker_group.master_port), + "TORCHELASTIC_RESTART_COUNT": str(restart_count), + "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), + "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), + "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), + "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( + "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) + ), + } + self._set_local_rank_env(worker_env, local_rank, spec) + if "OMP_NUM_THREADS" in os.environ: + worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] + + if self._log_line_prefix_template: + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( + role_name=spec.role, + rank=worker.global_rank, + local_rank=local_rank, + ) + # pyrefly: ignore [unsupported-operation] + log_line_prefixes[local_rank] = log_line_prefix + + # pyrefly: ignore [unsupported-operation] + envs[local_rank] = worker_env + worker_args = list(spec.args) + worker_args = macros.substitute(worker_args, str(local_rank)) + args[local_rank] = tuple(worker_args) + + self._setup_local_watchdog(envs=envs) + self._setup_healthcheck() + + assert spec.entrypoint is not None + assert self._logs_specs is not None + self._pcontext = start_processes( + name=spec.role, + entrypoint=spec.entrypoint, + args=args, + envs=envs, + logs_specs=self._logs_specs, + log_line_prefixes=log_line_prefixes, + start_method=self._start_method, + numa_options=spec.numa_options, + duplicate_stdout_filters=spec.duplicate_stdout_filters, + duplicate_stderr_filters=spec.duplicate_stderr_filters, + ) + + return self._pcontext.pids() + + def _set_local_rank_env( + self, worker_env: dict[str, str | None], local_rank: int, spec: WorkerSpec + ) -> None: + # Set CUDA_VISIBLE_DEVICES and LOCAL_RANK based on virtual_local_rank mode. + # Virtual mode: Each worker sees only its assigned GPU as device 0, LOCAL_RANK=0 + # Traditional mode: Workers see all GPUs, LOCAL_RANK matches actual local rank + + if spec.virtual_local_rank: + # Set LOCAL_RANK=0 and use CUDA_VISIBLE_DEVICES to control the actual GPU access. + + worker_env["LOCAL_RANK"] = "0" + + # Map local_rank through existing CUDA_VISIBLE_DEVICES + # HIP uses CUDA_VISIBLE_DEVICES as a compatibility hack: + # https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#cuda-visible-devices + parent_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if parent_visible_devices is not None: + # Parse comma-separated list of GPU IDs + available_gpus = parent_visible_devices.split(",") + if local_rank >= len(available_gpus): + raise ValueError( + f"local_rank {local_rank} exceeds available GPUs in " + f"CUDA_VISIBLE_DEVICES={parent_visible_devices}" + ) + + visible_gpu = available_gpus[local_rank].strip() + else: + # No restriction, use local_rank directly + visible_gpu = str(local_rank) + + worker_env["CUDA_VISIBLE_DEVICES"] = visible_gpu + return + + # In traditional mode, don't override CUDA_VISIBLE_DEVICES + # (inherit from parent environment) + worker_env["LOCAL_RANK"] = str(local_rank) + + if "CUDA_VISIBLE_DEVICES" in os.environ: + worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] + + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: + if self._worker_watchdog is not None: + self._worker_watchdog.stop() + self._worker_watchdog = None + if self._health_check_server is not None: + self._health_check_server.stop() + self._health_check_server = None + if self._pcontext: + self._pcontext.close(death_sig) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + # `torch.distributed.elastic.metrics.prof`. + @prof + def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: + role = worker_group.spec.role + worker_pids = {w.id for w in worker_group.workers} + assert self._pcontext is not None + pc_pids = set(self._pcontext.pids().values()) + if worker_pids != pc_pids: + logger.error( + "[%s] worker pids do not match process_context pids." + " Expected: %s, actual: %s", + role, + worker_pids, + pc_pids, + ) + return RunResult(state=WorkerState.UNKNOWN) + + result = self._pcontext.wait(0) + if result: + if result.is_failed(): + # map local rank failure to global rank + worker_failures = {} + for local_rank, failure in result.failures.items(): + worker = worker_group.workers[local_rank] + worker_failures[worker.global_rank] = failure + return RunResult( + state=WorkerState.FAILED, + failures=worker_failures, + ) + else: + # copy ret_val_queue into a map with a global ranks + workers_ret_vals = {} + for local_rank, ret_val in result.return_values.items(): + worker = worker_group.workers[local_rank] + workers_ret_vals[worker.global_rank] = ret_val + return RunResult( + state=WorkerState.SUCCEEDED, + return_values=workers_ret_vals, + ) + else: + return RunResult(state=WorkerState.HEALTHY) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py new file mode 100644 index 0000000000000000000000000000000000000000..817255edd23dcee2deea8554ada3637d30f9885f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/control_plane.py @@ -0,0 +1,53 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager, ExitStack + +from torch.distributed.elastic.multiprocessing.errors import record + + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@record +@contextmanager +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..deea40f3899aee490a899cfa1dd6d3019512cb9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__init__.py @@ -0,0 +1,173 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module contains events processing mechanisms that are integrated with the standard python logging. + +Example of usage: + +:: + + from torch.distributed.elastic import events + + event = events.Event( + name="test_event", source=events.EventSource.WORKER, metadata={...} + ) + events.get_logging_handler(destination="console").info(event) + +""" + +import inspect +import logging +import os +import socket +import traceback +from typing import Optional + +from torch.distributed.elastic.events.handlers import get_logging_handler + +from .api import ( # noqa: F401 + Event, + EventMetadataValue, + EventSource, + NodeState, + RdzvEvent, +) + + +_events_loggers: dict[str, logging.Logger] = {} + + +def _get_or_create_logger(destination: str = "null") -> logging.Logger: + """ + Construct python logger based on the destination type or extends if provided. + + Available destination could be found in ``handlers.py`` file. + The constructed logger does not propagate messages to the upper level loggers, + e.g. root logger. This makes sure that a single event can be processed once. + + Args: + destination: The string representation of the event handler. + Available handlers found in ``handlers`` module + """ + global _events_loggers + + if destination not in _events_loggers: + _events_logger = logging.getLogger(f"torchelastic-events-{destination}") + _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) + # Do not propagate message to the root logger + _events_logger.propagate = False + + logging_handler = get_logging_handler(destination) + _events_logger.addHandler(logging_handler) + + # Add the logger to the global dictionary + _events_loggers[destination] = _events_logger + + return _events_loggers[destination] + + +def record(event: Event, destination: str = "null") -> None: + _get_or_create_logger(destination).info(event.serialize()) + + +def record_rdzv_event(event: RdzvEvent) -> None: + _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) + + +def construct_and_record_rdzv_event( + run_id: str, + message: str, + node_state: NodeState, + name: str = "", + hostname: str = "", + pid: int | None = None, + master_endpoint: str = "", + local_id: int | None = None, + rank: int | None = None, +) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... rank=rank, + ... ) + """ + # We don't want to perform an extra computation if not needed. + if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): + return + + # Set up parameters. + if not hostname: + hostname = socket.getfqdn() + if not pid: + pid = os.getpid() + + # Determines which file called this function. + callstack = inspect.stack() + filename = "no_file" + if len(callstack) > 1: + stack_depth_1 = callstack[1] + filename = os.path.basename(stack_depth_1.filename) + if not name: + name = stack_depth_1.function + + # Delete the callstack variable. If kept, this can mess with python's + # garbage collector as we are holding on to stack frame information in + # the inspect module. + del callstack + + # Set up error trace if this is an exception + if node_state == NodeState.FAILED: + error_trace = traceback.format_exc() + else: + error_trace = "" + + # Initialize event object + event = RdzvEvent( + name=f"{filename}:{name}", + run_id=run_id, + message=message, + hostname=hostname, + pid=pid, + node_state=node_state, + master_endpoint=master_endpoint, + rank=rank, + local_id=local_id, + error_trace=error_trace, + ) + + # Finally, record the event. + record_rdzv_event(event) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d6b063a0360bfe57e91cd61cada2a05b6e11bb5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bdb38d22b11c1ad7836ccb7b3ced0838bcc9ab4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..886e08d158e0ed7a956fd8f38768da4bb67873ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py new file mode 100644 index 0000000000000000000000000000000000000000..31afe29ff5f597b27b453e9993e1257e3f1f8d2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/api.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Union + + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] + +EventMetadataValue = Union[str, int, float, bool, None] + + +class EventSource(str, Enum): + """Known identifiers of the event producers.""" + + AGENT = "AGENT" + WORKER = "WORKER" + + +@dataclass +class Event: + """ + The class represents the generic event that occurs during the torchelastic job execution. + + The event can be any kind of meaningful action. + + Args: + name: event name. + source: the event producer, e.g. agent or worker + timestamp: timestamp in milliseconds when event occurred. + metadata: additional data that is associated with the event. + """ + + name: str + source: EventSource + timestamp: int = 0 + metadata: dict[str, EventMetadataValue] = field(default_factory=dict) + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "Event"]) -> "Event": + if isinstance(data, Event): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + return Event(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) + + +class NodeState(str, Enum): + """The states that a node can be in rendezvous.""" + + INIT = "INIT" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + +@dataclass +class RdzvEvent: + """ + Dataclass to represent any rendezvous event. + + Args: + name: Event name. (E.g. Current action being performed) + run_id: The run id of the rendezvous + message: The message describing the event + hostname: Hostname of the node + pid: The process id of the node + node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED) + master_endpoint: The master endpoint for the rendezvous store, if known + rank: The rank of the node, if known + local_id: The local_id of the node, if defined in dynamic_rendezvous.py + error_trace: Error stack trace, if this is an error event. + """ + + name: str + run_id: str + message: str + hostname: str + pid: int + node_state: NodeState + master_endpoint: str = "" + rank: int | None = None + local_id: int | None = None + error_trace: str = "" + + def __str__(self): + return self.serialize() + + @staticmethod + def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent": + if isinstance(data, RdzvEvent): + return data + if isinstance(data, str): + data_dict = json.loads(data) + data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + return RdzvEvent(**data_dict) + + def serialize(self) -> str: + return json.dumps(asdict(self)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..30d925353253d5bab4c4780f298e7fa68a4409e5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/events/handlers.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +_log_handlers: dict[str, logging.Handler] = { + "console": logging.StreamHandler(), + "dynamic_rendezvous": logging.NullHandler(), + "null": logging.NullHandler(), +} + + +def get_logging_handler(destination: str = "null") -> logging.Handler: + global _log_handlers + return _log_handlers[destination] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c2330924879ddbe35629a82d94a9b0c4c9c339 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__init__.py @@ -0,0 +1,168 @@ +#!/usr/bin/env/python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Metrics API. + +**Overview**: + +The metrics API in torchelastic is used to publish telemetry metrics. +It is designed to be used by torchelastic's internal modules to +publish metrics for the end user with the goal of increasing visibility +and helping with debugging. However you may use the same API in your +jobs to publish metrics to the same metrics ``sink``. + +A ``metric`` can be thought of as timeseries data +and is uniquely identified by the string-valued tuple +``(metric_group, metric_name)``. + +torchelastic makes no assumptions about what a ``metric_group`` is +and what relationship it has with ``metric_name``. It is totally up +to the user to use these two fields to uniquely identify a metric. + +.. note:: The metric group ``torchelastic`` is reserved by torchelastic for + platform level metrics that it produces. + For instance torchelastic may output the latency (in milliseconds) + of a re-rendezvous operation from the agent as + ``(torchelastic, agent.rendezvous.duration.ms)`` + +A sensible way to use metric groups is to map them to a stage or module +in your job. You may also encode certain high level properties +the job such as the region or stage (dev vs prod). + +**Publish Metrics**: + +Using torchelastic's metrics API is similar to using python's logging +framework. You first have to configure a metrics handler before +trying to add metric data. + +The example below measures the latency for the ``calculate()`` function. + +:: + + import time + import torch.distributed.elastic.metrics as metrics + + # makes all metrics other than the one from "my_module" to go /dev/null + metrics.configure(metrics.NullMetricsHandler()) + metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") + + + def my_method(): + start = time.time() + calculate() + end = time.time() + metrics.put_metric("calculate_latency", int(end - start), "my_module") + +You may also use the torch.distributed.elastic.metrics.prof` decorator +to conveniently and succinctly profile functions + +:: + + # -- in module examples.foobar -- + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") + metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") + + + @metrics.prof + def foo(): + pass + + + class Bar: + @metrics.prof + def baz(): + pass + +``@metrics.prof`` will publish the following metrics +:: + + .success - 1 if the function finished successfully + .failure - 1 if the function threw an exception + .duration.ms - function duration in milliseconds + +**Configuring Metrics Handler**: + +`torch.distributed.elastic.metrics.MetricHandler` is responsible for emitting +the added metric values to a particular destination. Metric groups can be +configured with different metric handlers. + +By default torchelastic emits all metrics to ``/dev/null``. +By adding the following configuration metrics, +``torchelastic`` and ``my_app`` metric groups will be printed out to +console. + +:: + + import torch.distributed.elastic.metrics as metrics + + metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic") + metrics.configure(metrics.ConsoleMetricHandler(), group="my_app") + +**Writing a Custom Metric Handler**: + +If you want your metrics to be emitted to a custom location, implement +the `torch.distributed.elastic.metrics.MetricHandler` interface +and configure your job to use your custom metric handler. + +Below is a toy example that prints the metrics to ``stdout`` + +:: + + import torch.distributed.elastic.metrics as metrics + + + class StdoutMetricHandler(metrics.MetricHandler): + def emit(self, metric_data): + ts = metric_data.timestamp + group = metric_data.group_name + name = metric_data.name + value = metric_data.value + print(f"[{ts}][{group}]: {name}={value}") + + + metrics.configure(StdoutMetricHandler(), group="my_app") + +Now all metrics in the group ``my_app`` will be printed to stdout as: + +:: + + [1574213883.4182858][my_app]: my_metric= + [1574213940.5237644][my_app]: my_metric= + +""" + +from typing import Optional + +from .api import ( # noqa: F401 + configure, + ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, + MetricData, + MetricHandler, + MetricsConfig, + NullMetricHandler, + prof, + profile, + publish_metric, + put_metric, +) + + +def initialize_metrics(cfg: MetricsConfig | None = None): + pass + + +try: + from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403 +except ModuleNotFoundError: + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ad71c7c9137545ee12103f2c13f4fa43725595b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afdf113e704df6765c032b10f0e1e40916e5b5e6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py new file mode 100644 index 0000000000000000000000000000000000000000..102049481538d15a7fe995a8602ba45d6842303e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import time +from collections import namedtuple +from functools import wraps +from typing_extensions import deprecated + + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] + +MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) + + +class MetricsConfig: + __slots__ = ["params"] + + def __init__(self, params: dict[str, str] | None = None): + self.params = params + if self.params is None: + self.params = {} + + +class MetricHandler(abc.ABC): + @abc.abstractmethod + def emit(self, metric_data: MetricData): + pass + + +class ConsoleMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + print( + f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}" + ) + + +class NullMetricHandler(MetricHandler): + def emit(self, metric_data: MetricData): + pass + + +class MetricStream: + def __init__(self, group_name: str, handler: MetricHandler): + self.group_name = group_name + self.handler = handler + + def add_value(self, metric_name: str, metric_value: int): + self.handler.emit( + MetricData(time.time(), self.group_name, metric_name, metric_value) + ) + + +_metrics_map: dict[str, MetricHandler] = {} +_default_metrics_handler: MetricHandler = NullMetricHandler() + + +# pyre-fixme[9]: group has type `str`; used as `None`. +def configure(handler: MetricHandler, group: str | None = None): + if group is None: + global _default_metrics_handler + # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used + # as `MetricHandler`. + _default_metrics_handler = handler + else: + _metrics_map[group] = handler + + +def getStream(group: str): + handler = _metrics_map.get(group, _default_metrics_handler) + return MetricStream(group, handler) + + +def _get_metric_name(fn): + qualname = fn.__qualname__ + split = qualname.split(".") + if len(split) == 1: + module = fn.__module__ + if module: + return module.split(".")[-1] + "." + split[0] + else: + return split[0] + else: + return qualname + + +def prof(fn=None, group: str = "torchelastic"): + r""" + @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. + + The metric name defaults to the qualified name (``class_name.def_name``) of the function. + If the function does not belong to a class, it uses the leaf module name instead. + + Usage + + :: + + @metrics.prof + def x(): + pass + + + @metrics.prof(group="agent") + def y(): + pass + """ + + def wrap(f): + @wraps(f) + def wrapper(*args, **kwargs): + key = _get_metric_name(f) + try: + start = time.time() + result = f(*args, **kwargs) + put_metric(f"{key}.success", 1, group) + except Exception: + put_metric(f"{key}.failure", 1, group) + raise + finally: + put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] + return result + + return wrapper + + if fn: + return wrap(fn) + else: + return wrap + + +@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) +def profile(group=None): + """ + @profile decorator adds latency and success/failure metrics to any given function. + + Usage + + :: + + @metrics.profile("my_metric_group") + def some_function(): + """ + + def wrap(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + start_time = time.time() + result = func(*args, **kwargs) + # pyrefly: ignore [bad-argument-type] + publish_metric(group, f"{func.__name__}.success", 1) + except Exception: + # pyrefly: ignore [bad-argument-type] + publish_metric(group, f"{func.__name__}.failure", 1) + raise + finally: + publish_metric( + # pyrefly: ignore [bad-argument-type] + group, + f"{func.__name__}.duration.ms", + get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] + ) + return result + + return wrapper + + return wrap + + +def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): + """ + Publish a metric data point. + + Usage + + :: + + put_metric("metric_name", 1) + put_metric("metric_name", 1, "metric_group_name") + """ + getStream(metric_group).add_value(metric_name, metric_value) + + +@deprecated( + "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", + category=FutureWarning, +) +def publish_metric(metric_group: str, metric_name: str, metric_value: int): + metric_stream = getStream(metric_group) + metric_stream.add_value(metric_name, metric_value) + + +def get_elapsed_time_ms(start_time_in_seconds: float): + """Return the elapsed time in millis from the given start time.""" + end_time = time.time() + return int((end_time - start_time_in_seconds) * 1000) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60b7cd32fd2531a3e3b04416b75a29767ba835fa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__init__.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary. + +For functions, it uses ``torch.multiprocessing`` (and therefore python +``multiprocessing``) to spawn/fork worker processes. For binaries it uses python +``subprocessing.Popen`` to create worker processes. + + +Usage 1: Launching two trainers as a function + +:: + + from torch.distributed.elastic.multiprocessing import Std, start_processes + + + def trainer(a, b, c): + pass # train + + + # runs two trainers + # LOCAL_RANK=0 trainer(1,2,3) + # LOCAL_RANK=1 trainer(4,5,6) + ctx = start_processes( + name="trainer", + entrypoint=trainer, + args={0: (1, 2, 3), 1: (4, 5, 6)}, + envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, + log_dir="/tmp/foobar", + redirects=Std.ALL, # write all worker stdout/stderr to a log file + tee={0: Std.ERR}, # tee only local rank 0's stderr to console + ) + + # waits for all copies of trainer to finish + ctx.wait() + +Usage 2: Launching 2 echo workers as a binary + +:: + + # same as invoking + # echo hello + # echo world > stdout.log + ctx = start_processes( + name="echo" + entrypoint="echo", + log_dir="/tmp/foobar", + args={0: "hello", 1: "world"}, + redirects={1: Std.OUT}, + ) + +Just like ``torch.multiprocessing``, the return value of the function +:func:`start_processes` is a process context (:class:`api.PContext`). If a function +was launched, a :class:`api.MultiprocessContext` is returned and if a binary +was launched a :class:`api.SubprocessContext` is returned. Both are specific +implementations of the parent :class:`api.PContext` class. +""" + +from collections.abc import Callable +from typing import Optional, Union + +from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 + _validate_full_rank, + DefaultLogsSpecs, + LogsDest, + LogsSpecs, + MultiprocessContext, + PContext, + ProcessFailure, + RunProcsResult, + SignalException, + Std, + SubprocessContext, + to_map, +) +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = [ + "start_processes", + "MultiprocessContext", + "PContext", + "ProcessFailure", + "RunProcsResult", + "SignalException", + "Std", + "LogsDest", + "LogsSpecs", + "DefaultLogsSpecs", + "SubprocessContext", + "to_map", +] + + +def start_processes( + name: str, + entrypoint: Callable | str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + start_method: str = "spawn", + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, +) -> PContext: + """ + Start ``n`` copies of ``entrypoint`` processes with the provided options. + + ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). + The number of copies is determined by the number of entries for ``args`` and + ``envs`` arguments, which need to have the same key set. + + ``args`` and ``env`` parameters are the arguments and environment variables + to pass down to the entrypoint mapped by the replica index (local rank). + All local ranks must be accounted for. + That is, the keyset should be ``{0,1,...,(nprocs-1)}``. + + .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. + If any other type is given, then it is casted to a string representation + (e.g. ``str(arg1)``). Furthermore, a binary failure will only write + an ``error.json`` error file if the main function is annotated with + ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches, + this is done by default and there is no need to manually annotate + with the ``@record`` annotation. + + Inside ``logs_specs``, ``redirects`` and ``tee`` are bitmasks specifying which std + stream(s) to redirect to a log file in the ``log_dir``. Valid mask values are defined + in ``Std``. To redirect/tee only certain local ranks, pass ``redirects`` as a map + with the key as the local rank to specify the redirect behavior for. + Any missing local ranks will default to ``Std.NONE``. + + ``duplicate_stdout_filters`` and ``duplicate_stderr_filters``, if non-empty, + duplicate stdouts and stderrs respectively specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. The log + file is aggregated across all ranks selected by ``tee``. + + ``tee`` acts like the unix "tee" command in that it redirects + prints to console. + To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. + + For each process, the ``log_dir`` will contain: + + #. ``{local_rank}/error.json``: if the process failed, a file with the error info + #. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT`` + #. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR`` + #. ``filtered_stdout.log``: if ``duplicate_stdout_filters`` is non-empty + #. ``filtered_stderr.log``: if ``duplicate_stderr_filters`` is non-empty + + .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. + + Example: + :: + + log_dir = "/tmp/test" + + # ok; two copies of foo: foo("bar0"), foo("bar1") + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # invalid; envs missing for local rank 1 + start_processes( + name="trainer", + entrypoint=foo, + args:{0:("bar0",), 1:("bar1",), + envs:{0:{}}, + log_dir=log_dir + ) + + # ok; two copies of /usr/bin/touch: touch file1, touch file2 + start_processes( + name="trainer", + entrypoint="/usr/bin/touch", + args:{0:("file1",), 1:("file2",), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + # caution; arguments casted to string, runs: + # echo "1" "2" "3" and echo "[1, 2, 3]" + start_processes( + name="trainer", + entrypoint="/usr/bin/echo", + args:{0:(1,2,3), 1:([1,2,3],), + envs:{0:{}, 1:{}}, + log_dir=log_dir + ) + + Args: + name: a human readable short name that describes what the processes are + (used as header when tee'ing stdout/stderr outputs) + entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) + args: arguments to each replica + envs: env vars to each replica + log_dir: directory used to write log files + start_method: multiprocessing start method (spawn, fork, forkserver) + ignored for binaries + logs_specs: defines ``log_dir``, ``redirects``, and ``tee``. + inside ``logs_specs``: + - redirects: which std streams to redirect to a log file + - tee: which std streams to redirect + print to console + local_ranks_filter: which ranks' logs to print to console + duplicate_stdout_filters: filters for the duplicated stdout logs + duplicate_stderr_filters: filters for the duplicated stderr logs + + """ + + nprocs = len(args) + _validate_full_rank(args, nprocs, "args") + _validate_full_rank(envs, nprocs, "envs") + + context: PContext + if isinstance(entrypoint, str): + context = SubprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + duplicate_stdout_filters=duplicate_stdout_filters, + duplicate_stderr_filters=duplicate_stderr_filters, + logs_specs=logs_specs, + log_line_prefixes=log_line_prefixes, + numa_options=numa_options, + ) + else: + context = MultiprocessContext( + name=name, + entrypoint=entrypoint, + args=args, + envs=envs, + duplicate_stdout_filters=duplicate_stdout_filters, + duplicate_stderr_filters=duplicate_stderr_filters, + log_line_prefixes=log_line_prefixes, + start_method=start_method, + logs_specs=logs_specs, + numa_options=numa_options, + ) + + try: + context.start() + return context + except Exception: + context.close() + raise diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5022ed8c07179ba3bd39b7235239986169d0dd5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8f066a54e89370908d93a91003280c624c36832 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cce656006ea1ddff348a34aaa218ce12004166ee Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df70dd46b7dd8321d85d9802710969d13377f4c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py new file mode 100644 index 0000000000000000000000000000000000000000..45351c380ca0db821149edd174cb588192619be6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py @@ -0,0 +1,1036 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import logging +import os +import re +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Callable +from contextlib import nullcontext +from dataclasses import dataclass, field +from enum import IntFlag +from multiprocessing import synchronize +from types import FrameType +from typing import Any, TextIO, Union + +import torch.multiprocessing as mp +from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record +from torch.distributed.elastic.multiprocessing.redirects import ( + redirect_stderr, + redirect_stdout, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) +from torch.distributed.elastic.multiprocessing.tail_log import TailLog +from torch.numa.binding import maybe_wrap_with_numa_binding, NumaOptions + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + +__all__ = [ + "DefaultLogsSpecs", + "SignalException", + "Std", + "to_map", + "RunProcsResult", + "PContext", + "get_std_cm", + "MultiprocessContext", + "SubprocessContext", + "LogsDest", + "LogsSpecs", +] + + +class SignalException(Exception): + """ + Exception is raised inside the torchelastic agent process by the termination handler + if the death signal got received by the process. + """ + + def __init__(self, msg: str, sigval: signal.Signals) -> None: + super().__init__(msg) + self.sigval = sigval + + +def _terminate_process_handler(signum: int, frame: FrameType | None) -> None: + """Termination handler that raises exceptions on the main process. + + When the process receives death signal(SIGTERM, SIGINT), this termination handler will + be invoked. It raises the ``SignalException`` exception that should be processed by the + user code. Python does not terminate process after the termination handler is finished, + so the exception should not be silently ignored, otherwise the process will never + be terminated. + """ + sigval = signal.Signals(signum) + raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) + + +def _get_kill_signal() -> signal.Signals: + """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGKILL + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +def _validate_full_rank(d: dict[int, Any], nprocs: int, what: str): + actual_keys = set(d.keys()) + expected_keys = set(range(nprocs)) + + if actual_keys != expected_keys: + raise RuntimeError( + f"{what}, local rank mapping mismatch," + f" expected: {expected_keys}, actual: {actual_keys}" + ) + + +_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$" +_VALUE_REGEX = r"^[0123]$" + + +class Std(IntFlag): + NONE = 0 + OUT = 1 + ERR = 2 + ALL = OUT | ERR + + @classmethod + def from_str(cls, vm: str) -> Union["Std", dict[int, "Std"]]: + """ + Example: + :: + + from_str("0") -> Std.NONE + from_str("1") -> Std.OUT + from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR} + + Any other input raises an exception + """ + + def to_std(v: str) -> Std: # type: ignore[return] + s = Std(int(v)) + if s in Std: + return s + # return None -> should NEVER reach here since we regex check input + + if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) + return to_std(vm) + elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) + d: dict[int, Std] = {} + for m in vm.split(","): + i, v = m.split(":") + d[int(i)] = to_std(v) + return d + else: + raise ValueError( + f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>" + ) + + +def to_map(val_or_map: Std | dict[int, Std], local_world_size: int) -> dict[int, Std]: + """ + Certain APIs take redirect settings either as a single value (e.g. apply to all + local ranks) or as an explicit user-provided mapping. This method is a convenience + method that converts a value or mapping into a mapping. + + Example: + :: + + to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} + to_map( + {0: Std.OUT, 1: Std.OUT}, local_world_size=2 + ) # returns: {0: Std.OUT, 1: Std.OUT} + """ + if isinstance(val_or_map, Std): + return dict.fromkeys(range(local_world_size), val_or_map) + else: + map = {} + for i in range(local_world_size): + map[i] = val_or_map.get(i, Std.NONE) + return map + + +@dataclass +class LogsDest: + """ + For each log type, holds mapping of local rank ids to file paths. + """ + + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + tee_stdouts: dict[int, str] = field(default_factory=dict) + tee_stderrs: dict[int, str] = field(default_factory=dict) + error_files: dict[int, str] = field(default_factory=dict) + filtered_stdout: str = field(default_factory=str) + filtered_stderr: str = field(default_factory=str) + + +class LogsSpecs(ABC): + """ + Defines logs processing and redirection for each worker process. + + Args: + log_dir: + Base directory where logs will be written. + redirects: + Streams to redirect to files. Pass a single ``Std`` + enum to redirect for all workers, or a mapping keyed + by local_rank to selectively redirect. + tee: + Streams to duplicate to stdout/stderr. + Pass a single ``Std`` enum to duplicate streams for all workers, + or a mapping keyed by local_rank to selectively duplicate. + """ + + def __init__( + self, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, + ) -> None: + self._root_log_dir = log_dir + self._redirects = redirects + self._tee = tee + self._local_ranks_filter = local_ranks_filter + + @abstractmethod + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Given the environment variables, builds destination of log files for each of the local ranks. + + Envs parameter contains env variables dict for each of the local ranks, where entries are defined in: + :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`. + """ + + @property + @abstractmethod + def root_log_dir(self) -> str: + pass + + +class DefaultLogsSpecs(LogsSpecs): + """ + Default LogsSpecs implementation: + + - `log_dir` will be created if it doesn't exist + - Generates nested folders for each attempt and rank. + """ + + def __init__( + self, + log_dir: str | None = None, + redirects: Std | dict[int, Std] = Std.NONE, + tee: Std | dict[int, Std] = Std.NONE, + local_ranks_filter: set[int] | None = None, + ) -> None: + if log_dir != os.devnull: + if not log_dir: + log_dir = tempfile.mkdtemp(prefix="torchelastic_") + elif not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + else: + if os.path.isfile(log_dir): + raise NotADirectoryError(f"log_dir: {log_dir} is a file") + super().__init__(log_dir, redirects, tee, local_ranks_filter) + # initialized only once + self._run_log_dir = None + + @property + def root_log_dir(self) -> str: + return str(self._root_log_dir) + + def _make_log_dir(self, log_dir: str | None, rdzv_run_id: str): + base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_") + os.makedirs(base_log_dir, exist_ok=True) + dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir) + logger.info("log directory set to: %s", dir) + return dir + + def reify( + self, + envs: dict[int, dict[str, str]], + ) -> LogsDest: + """ + Uses following scheme to build log destination paths: + + - `//attempt_//stdout.log` + - `//attempt_//stderr.log` + - `//attempt_//error.json` + - `//attempt_/filtered_stdout.log` + - `//attempt_/filtered_stderr.log` + """ + nprocs = len(envs) + global_env = {} # use only to query properties that are not dependent on a rank + if nprocs > 0: + global_env = envs[0] + else: + logger.warning( + "Empty envs map provided when defining logging destinations." + ) + # Keys are always defined, but values can be missing in unit tests + run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") + restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") + + attempt_log_dir: str = "" + if self._root_log_dir != os.devnull: + if not self._run_log_dir: + self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) + + attempt_log_dir = os.path.join( + self._run_log_dir, f"attempt_{restart_count}" + ) # type: ignore[call-overload] + shutil.rmtree(attempt_log_dir, ignore_errors=True) + os.makedirs(attempt_log_dir) + + if self._root_log_dir == os.devnull: + attempt_log_dir = os.devnull + + # create subdirs for each local rank in the logs_dir + # logs_dir + # |- 0 + # |- error.json + # |- stdout.log + # |- stderr.log + # |- ... + # |- (nprocs-1) + redirs = to_map(self._redirects, nprocs) + ts = to_map(self._tee, nprocs) + + # to tee stdout/stderr we first redirect into a file + # then tail -f stdout.log/stderr.log so add tee settings to redirects + for local_rank, tee_std in ts.items(): + redirect_std = redirs[local_rank] + redirs[local_rank] = redirect_std | tee_std + + SYS_STREAM = "" # special case to indicate to output to console + stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) + stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) + tee_stdouts: dict[int, str] = {} + tee_stderrs: dict[int, str] = {} + error_files = {} + + for local_rank in range(nprocs): + if attempt_log_dir == os.devnull: + tee_stdouts[local_rank] = os.devnull + tee_stderrs[local_rank] = os.devnull + error_files[local_rank] = os.devnull + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" + else: + clogdir = os.path.join(attempt_log_dir, str(local_rank)) + os.mkdir(clogdir) + + rd = redirs[local_rank] + if (rd & Std.OUT) == Std.OUT: + stdouts[local_rank] = os.path.join(clogdir, "stdout.log") + if (rd & Std.ERR) == Std.ERR: + stderrs[local_rank] = os.path.join(clogdir, "stderr.log") + + t = ts[local_rank] + if t & Std.OUT == Std.OUT: + tee_stdouts[local_rank] = stdouts[local_rank] + if t & Std.ERR == Std.ERR: + tee_stderrs[local_rank] = stderrs[local_rank] + + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): + # If stream is tee'd, only write to file, but don't tail + if local_rank in tee_stdouts: + tee_stdouts.pop(local_rank, None) + if local_rank in tee_stderrs: + tee_stderrs.pop(local_rank, None) + + # If stream is not redirected, don't print + if stdouts[local_rank] == SYS_STREAM: + stdouts[local_rank] = os.devnull + if stderrs[local_rank] == SYS_STREAM: + stderrs[local_rank] = os.devnull + + error_file = os.path.join(clogdir, "error.json") + error_files[local_rank] = error_file + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) + envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file + + return LogsDest( + stdouts, + stderrs, + tee_stdouts, + tee_stderrs, + error_files, + os.path.join(attempt_log_dir, "filtered_stdout.log"), + os.path.join(attempt_log_dir, "filtered_stderr.log"), + ) + + def __repr__(self) -> str: + return ( + f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, " + f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DefaultLogsSpecs): + return False + + return ( + self._root_log_dir == other._root_log_dir + and self._redirects == other._redirects + and self._tee == other._tee + and self._local_ranks_filter == other._local_ranks_filter + ) + + +@dataclass +class RunProcsResult: + """ + Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``. + + Note the following: + + 1. All fields are mapped by local rank + 2. ``return_values`` - only populated for functions (not the binaries). + 3. ``stdouts`` - path to stdout.log (empty string if no redirect) + 4. ``stderrs`` - path to stderr.log (empty string if no redirect) + + """ + + return_values: dict[int, Any] = field(default_factory=dict) + failures: dict[int, ProcessFailure] = field(default_factory=dict) + stdouts: dict[int, str] = field(default_factory=dict) + stderrs: dict[int, str] = field(default_factory=dict) + + def is_failed(self) -> bool: + return len(self.failures) > 0 + + +class PContext(abc.ABC): + """ + The base class that standardizes operations over a set of processes that are launched via different mechanisms. + + The name ``PContext`` is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``. + + .. warning:: stdouts and stderrs should ALWAYS be a superset of + tee_stdouts and tee_stderrs (respectively) this is b/c + tee is implemented as a redirect + tail -f + + Args: + duplicate_stdout_filters: + If non-empty, duplicates stdouts specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. + The log file is aggregated across all ranks selected by ``tee``. + duplicate_stderr_filters: + If non-empty, duplicates stderrs specified in ``logs_specs``'s ``tee`` + to a file containing only lines that match _any_ of the filter strings. + The log file is aggregated across all ranks selected by ``tee``. + """ + + def __init__( + self, + name: str, + entrypoint: Callable | str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + self.name = name + # validate that all mappings have the same number of keys and + # all local ranks are accounted for + nprocs = len(args) + + # TODO log_line_prefixes can be expanded too + logs_dest = logs_specs.reify(envs) + + _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") + _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs") + + self.entrypoint = entrypoint + self.args = args + self.envs = envs + self.stdouts = logs_dest.stdouts + self.stderrs = logs_dest.stderrs + self.error_files = logs_dest.error_files + self.nprocs = nprocs + self.filtered_stdout: TextIO | None = None + self.filtered_stderr: TextIO | None = None + + self._tail_logs = [ + TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes), + TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes), + ] + + if duplicate_stdout_filters: + self.filtered_stdout = open( # noqa: SIM115 + logs_dest.filtered_stdout, mode="w", errors="replace", buffering=1 + ) + self._tail_logs.append( + TailLog( + name, + logs_dest.tee_stdouts, + self.filtered_stdout, + log_line_prefixes, + log_line_filter=lambda line: any( + needle in line for needle in duplicate_stdout_filters + ), + ) + ) + + if duplicate_stderr_filters: + self.filtered_stderr = open( # noqa: SIM115 + logs_dest.filtered_stderr, mode="w", errors="replace", buffering=1 + ) + self._tail_logs.append( + TailLog( + name, + logs_dest.tee_stderrs, + self.filtered_stderr, + log_line_prefixes, + log_line_filter=lambda line: any( + needle in line for needle in duplicate_stderr_filters + ), + ) + ) + + def start(self) -> None: + """Start processes using parameters defined in the constructor.""" + if threading.current_thread() is threading.main_thread(): + # Register signal handlers for the signals specified in the environment variable + signals_to_handle = os.environ.get( + "TORCHELASTIC_SIGNALS_TO_HANDLE", "SIGTERM,SIGINT,SIGHUP,SIGQUIT" + ) + signal_list = signals_to_handle.split(",") + + for sig_name in signal_list: + try: + sig = getattr(signal, sig_name.strip()) + signal.signal(sig, _terminate_process_handler) + logger.info("Registered signal handler for %s", sig_name) + except (AttributeError, ValueError): + logger.warning( + "Failed to register signal handler for %s", + sig_name, + exc_info=True, + ) + except RuntimeError: + if IS_WINDOWS and sig_name.strip() in [ + "SIGHUP", + "SIGQUIT", + "SIGUSR1", + "SIGUSR2", + ]: + logger.info( + "Signal %s is not supported on Windows, skipping", sig_name + ) + else: + logger.warning( + "Failed to register signal handler for %s", + sig_name, + exc_info=True, + ) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) + self._start() + for tail_log in self._tail_logs: + tail_log.start() + + @abc.abstractmethod + def _start(self) -> None: + """Start processes using strategy defined in a particular context.""" + raise NotImplementedError + + @abc.abstractmethod + def _poll(self) -> RunProcsResult | None: + """ + Poll the run status of the processes running under this context. + This method follows an "all-or-nothing" policy and returns + a ``RunProcessResults`` object if either all processes complete + successfully or any process fails. Returns ``None`` if + all processes are still running. + """ + raise NotImplementedError + + def wait(self, timeout: float = -1, period: float = 1) -> RunProcsResult | None: + """ + Wait for the specified ``timeout`` seconds, polling every ``period`` seconds + for the processes to be done. Returns ``None`` if the processes are still running + on timeout expiry. Negative timeout values are interpreted as "wait-forever". + A timeout value of zero simply queries the status of the processes (e.g. equivalent + to a poll). + + .. note:: + Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise + ``SignalException`` when the signals received. It is up to the consumer of the code + to properly handle the exception. It is important not to swallow the exception otherwise + the process would not terminate. Example of the typical workflow can be: + + .. code-block:: python + pc = start_processes(...) + try: + pc.wait(1) + .. do some other work + except SignalException as e: + pc.shutdown(e.sigval, timeout=30) + + If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating + received signal. If child processes will not terminate in the timeout time, the process will send + the SIGKILL. + """ + if timeout == 0: + return self._poll() + + if timeout < 0: + timeout = sys.maxsize + + expiry = time.time() + timeout + while time.time() < expiry: + pr = self._poll() + if pr: + return pr + time.sleep(period) + + return None + + @abc.abstractmethod + def pids(self) -> dict[int, int]: + """Return pids of processes mapped by their respective local_ranks.""" + raise NotImplementedError + + @abc.abstractmethod + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + """ + raise NotImplementedError + + def close(self, death_sig: signal.Signals | None = None, timeout: int = 30) -> None: + r""" + Terminates all processes managed by this context and cleans up any + meta resources (e.g. redirect, error_file files). + + Args: + death_sig: Death signal to terminate processes. + timeout: Time to wait for processes to finish, if process is + still alive after this time, it will be terminated via SIGKILL. + """ + if not death_sig: + death_sig = _get_default_signal() + self._close(death_sig=death_sig, timeout=timeout) + for tail_log in self._tail_logs: + tail_log.stop() + if self.filtered_stdout: + self.filtered_stdout.close() + if self.filtered_stderr: + self.filtered_stderr.close() + + +def get_std_cm(std_rd: str, redirect_fn): + if IS_WINDOWS or IS_MACOS or not std_rd: + return nullcontext() + else: + return redirect_fn(std_rd) + + +def _wrap( + local_rank: int, + fn: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + stdout_redirects: dict[int, str], # redirect file for stdout (to console if None) + stderr_redirects: dict[int, str], # redirect file for stderr (to console if None) + ret_vals: dict[int, mp.SimpleQueue], + queue_finished_reading_event: synchronize.Event, + numa_options: NumaOptions | None, +) -> None: + # get the per-rank params up front so we fail fast if no mapping is found + args_ = args[local_rank] + env_ = envs[local_rank] + ret_val_ = ret_vals[local_rank] + + stdout_rd = stdout_redirects[local_rank] + stderr_rd = stderr_redirects[local_rank] + + stdout_cm = get_std_cm(stdout_rd, redirect_stdout) + stderr_cm = get_std_cm(stderr_rd, redirect_stderr) + + for k, v in env_.items(): + os.environ[k] = v + + with stdout_cm, stderr_cm: + fn = maybe_wrap_with_numa_binding( + fn, gpu_index=local_rank, numa_options=numa_options + ) + ret = record(fn)(*args_) + ret_val_.put(ret) + queue_finished_reading_event.wait() + + +class MultiprocessContext(PContext): + """``PContext`` holding worker processes invoked as a function.""" + + def __init__( + self, + name: str, + entrypoint: Callable, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + start_method: str, + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + duplicate_stdout_filters, + duplicate_stderr_filters, + ) + + self.start_method = start_method + # each ret_val queue will always contain a single element. + self._ret_vals = { + local_rank: mp.get_context(self.start_method).SimpleQueue() + for local_rank in range(self.nprocs) + } + + # see comments in ``join()`` for what this is + self._return_values: dict[int, Any] = {} + self._pc: mp.ProcessContext | None = None + # Note: set method should ONLY be invoked for the use case when all processes finished + # successfully. If any process died on event.wait() calling set() method will deadlock. + self._worker_finished_event = mp.get_context(self.start_method).Event() + + self._numa_options: NumaOptions | None = numa_options + + def _start(self): + if self._pc: + raise ValueError( + "The process context already initialized." + " Most likely the start method got called twice." + ) + self._pc = mp.start_processes( + fn=_wrap, + args=( + self.entrypoint, + self.args, + self.envs, + self.stdouts, + self.stderrs, + self._ret_vals, + self._worker_finished_event, + self._numa_options, + ), + nprocs=self.nprocs, + join=False, + daemon=False, + start_method=self.start_method, + ) + + def _is_done(self) -> bool: + return len(self._return_values) == self.nprocs + + def _poll(self) -> RunProcsResult | None: + assert self._pc is not None # assertion for mypy type checker + + try: + # torch.mp.ProcessContext Throws an Exception if some/all of + # worker processes failed + # timeout < 0 checks worker status and return immediately + # Join will never return success since we use synchronize.Event to wait + # for all processes to finish. + self._pc.join(-1) + + # IMPORTANT: we use multiprocessing.Queue to carry worker return values + # back to the parent, the worker process will wait before terminating + # until all the buffered items are fed by the feeder thread to the underlying + # pipe. Hence to prevent deadlocks on large return values, + # we opportunistically try queue.get on each join call + # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms + for local_rank in range(self.nprocs): + return_queue = self._ret_vals[local_rank] + if not return_queue.empty(): + # save the return values temporarily into a member var + self._return_values[local_rank] = return_queue.get() + + if self._is_done(): + # we should ALWAYS have ALL the return values when all the processes are done + self._worker_finished_event.set() + + # At this point workers finished running the user function + # But the child process might still have not exited. Wait for them. + # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. + while not self._pc.join(): + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." + ) + + _validate_full_rank( + self._return_values, self.nprocs, "return_value queue" + ) + self.close() + return RunProcsResult( + return_values=self._return_values, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + else: + return None + except (mp.ProcessRaisedException, mp.ProcessExitedException) as e: + failed_local_rank = e.error_index + + # entrypoint for MultiprocessContext will always be a Callable + fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr] + failed_proc = self._pc.processes[failed_local_rank] + error_filepath = self.error_files[failed_local_rank] + + logger.exception( + "failed (exitcode: %s)" + " local_rank: %s (pid: %s)" + " of fn: %s (start_method: %s)", + failed_proc.exitcode, + failed_local_rank, + e.error_pid, + fn_name, + self.start_method, + ) + + self.close() + return RunProcsResult( + failures={ + failed_local_rank: ProcessFailure( + local_rank=failed_local_rank, + pid=e.error_pid, + exitcode=failed_proc.exitcode, + error_file=error_filepath, + ) + }, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + + def pids(self) -> dict[int, int]: + assert self._pc is not None # assertion for mypy type checking + return dict(enumerate(self._pc.pids())) + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self._pc: + return + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) + try: + os.kill(proc.pid, death_sig) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + end = time.monotonic() + timeout + for proc in self._pc.processes: + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + proc.join(time_to_wait) + for proc in self._pc.processes: + if proc.is_alive(): + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + proc.pid, + death_sig, + _get_kill_signal(), + ) + try: + os.kill(proc.pid, _get_kill_signal()) + except ProcessLookupError: + # If the process exited because of some reason, + # `ProcessLookupError` will be raised, it is safe to ignore it. + pass + proc.join() + + +class SubprocessContext(PContext): + """``PContext`` holding worker processes invoked as a binary.""" + + def __init__( + self, + name: str, + entrypoint: str, + args: dict[int, tuple], + envs: dict[int, dict[str, str]], + logs_specs: LogsSpecs, + log_line_prefixes: dict[int, str] | None = None, + numa_options: NumaOptions | None = None, + duplicate_stdout_filters: list[str] | None = None, + duplicate_stderr_filters: list[str] | None = None, + ): + super().__init__( + name, + entrypoint, + args, + envs, + logs_specs, + log_line_prefixes, + duplicate_stdout_filters, + duplicate_stderr_filters, + ) + + # state vector; _vdone[local_rank] -> is local_rank finished or not + self._running_local_ranks: set[int] = set(range(self.nprocs)) + self._failures: dict[int, ProcessFailure] = {} + self.subprocess_handlers: dict[int, SubprocessHandler] = {} + self._numa_options: NumaOptions | None = numa_options + + def _start(self): + if self.subprocess_handlers: + raise ValueError( + "The subprocess handlers already initialized. Most likely the start method got called twice." + ) + self.subprocess_handlers = { + local_rank: get_subprocess_handler( + entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str + args=self.args[local_rank], + env=self.envs[local_rank], + stdout=self.stdouts[local_rank], + stderr=self.stderrs[local_rank], + local_rank_id=local_rank, + numa_options=self._numa_options, + ) + for local_rank in range(self.nprocs) + } + + def _capture_process_failures(self, done_local_ranks: set[int]): + for local_rank in self._running_local_ranks: + handler = self.subprocess_handlers[local_rank] + exitcode = handler.proc.poll() + if exitcode is not None: + done_local_ranks.add(local_rank) + if exitcode != 0: # failed or signaled + self._failures[local_rank] = ProcessFailure( + local_rank=local_rank, + pid=handler.proc.pid, + exitcode=exitcode, + error_file=self.error_files[local_rank], + ) + # else: --> succeeded; nothing to do + + def _poll(self) -> RunProcsResult | None: + done_local_ranks: set[int] = set() + self._capture_process_failures(done_local_ranks) + + self._running_local_ranks.difference_update(done_local_ranks) + + # if ALL procs are finished or ANY have failed + if not self._running_local_ranks or self._failures: + self.close() # terminate all running procs + self._capture_process_failures( + done_local_ranks + ) # log sigterms and sigkill exit codes in the self._failures for bookkeeping purposes + + result = RunProcsResult( + failures=self._failures, + stdouts=self.stdouts, + stderrs=self.stderrs, + ) + if result.is_failed(): + first_failure = min(result.failures.values(), key=lambda f: f.timestamp) + logger.error( + "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s", + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, + ) + else: + # Populate return with dummy values. This provides consistency with MultiprocessingHandler + result.return_values = dict.fromkeys(range(self.nprocs)) + + return result + else: # there are no failures and procs still running + return None + + def pids(self) -> dict[int, int]: + return { + local_rank: sh.proc.pid + for local_rank, sh in self.subprocess_handlers.items() + } + + def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: + if not self.subprocess_handlers: + return + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, + ) + handler.close(death_sig=death_sig) + end = time.monotonic() + timeout + for handler in self.subprocess_handlers.values(): + time_to_wait = end - time.monotonic() + if time_to_wait <= 0: + break + try: + handler.proc.wait(time_to_wait) + except subprocess.TimeoutExpired: + # Ignore the timeout expired exception, since + # the child process will be forcefully terminated via SIGKILL + pass + for handler in self.subprocess_handlers.values(): + if handler.proc.poll() is None: + logger.warning( + "Unable to shutdown process %s via %s, forcefully exiting via %s", + handler.proc.pid, + death_sig, + _get_kill_signal(), + ) + handler.close(death_sig=_get_kill_signal()) + handler.proc.wait() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f61c99dc5c7779a5d839ca5b0364616b55079286 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Each host in a distributed PyTorch job runs with a single TorchElastic agent, +and multiple workers (as children processes of the TorchElastic agent). +Since the workers are user-provided (your PyTorch script/job), TorchElastic +has a way to propagate errors on the trainers through the agent and up to the +scheduler, which ultimately informs the end-user about the state of the job +and applies any retry policies. + +TorchElastic categorizes errors into 3 categories: + ++----------------+----------------+--------------------------------------------------------------+ +| Category | Sub-Category | Description | ++================+================+==============================================================+ +| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) | +| +----------------+--------------------------------------------------------------+ +| | Worker Failure | any failures on the worker child process | ++----------------+----------------+--------------------------------------------------------------+ +| Platform Error | n/a | failures caused by the agent | ++----------------+----------------+--------------------------------------------------------------+ +| Infra Error | n/a | failures outside the domain of the agent and workers | +| | | (e.g. host failures) | ++----------------+----------------+--------------------------------------------------------------+ + +All errors other than "Worker Failure" are either raised canonically from the +agent process or implicitly or explicitly crash the agent process. So the +standard language (python) provided exception handling strategies apply. + +Worker Failures are special because the exception/failure originates on a different +process from the agent so the error needs to be propagated inter-process +(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). + +TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes` +to launch the workers which has a simple file based inter-process error propagation +built-in. + +Any function or binary entrypoint decorated with :func:`record` +will write uncaught exceptions (with the trace information) to a file specified by the +environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) +sets this env var on each child it launches, then aggregates the error files for all +children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). +""" + +import json +import os +import signal +import socket +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime +from functools import wraps +from string import Template +from typing import Any, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.distributed.elastic.utils.logging import get_logger + +from .error_handler import ErrorHandler # noqa: F401 +from .handlers import get_error_handler # noqa: F401 + + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] + +logger = get_logger(__name__) + + +JSON = dict[str, Any] + +_EMPTY_ERROR_DATA: dict[str, Any] = {"message": ""} +_NOT_AVAILABLE = "" + +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +@dataclass +class ProcessFailure: + """ + Represent the failed process result. When the worker process fails, it may record failure root cause into the file. + + Tries to read the failure timestamp from the provided ``error_file``, + if the ``error_file`` does not exist, the timestamp is the current + timestamp (seconds since epoch). + + The ``message`` field is a concise explanation of the failure. If + the error file exists then the message is obtained from the error file. + Otherwise one is generated based on the failure signature. + + .. note:: It is assumed that the ``error_file`` is written by + ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. + Otherwise the behavior is undefined. + + """ + + local_rank: int + pid: int + exitcode: int + error_file: str + error_file_data: JSON = field(init=False) + message: str = field(init=False) + timestamp: int = field(init=False) + + def __post_init__(self): + self.error_file_data = _EMPTY_ERROR_DATA + if os.path.isfile(self.error_file): + try: + with open(self.error_file) as fp: + self.error_file_data = json.load(fp) + logger.debug( + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), + ) + self.message, self.timestamp = self._get_error_data( + self.error_file_data + ) + except Exception: + logger.exception("Failed to parse reply file: %s", self.error_file) + raise + else: + self._set_no_reply_file() + + # make up an informative message if not already present + if not self.message: + # signals typically do not generate an error file message + if self.exitcode < 0: + self.message = ( + f"Signal {-self.exitcode} ({self.signal_name()})" + f" received by PID {self.pid}" + ) + else: + self.error_file_data["errorTraits"] = { + "category": "system_terminated_error", + "retryability": "False", + } + self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" + + def _get_error_data(self, error_file_data: dict[str, Any]) -> tuple[str, int]: + message = error_file_data["message"] + if isinstance(message, str): + timestamp = int(error_file_data.get("timestamp", 0)) + else: + timestamp = int(message["extraInfo"]["timestamp"]) + return (message, timestamp) + + def _set_no_reply_file(self): + self.error_file = _NOT_AVAILABLE + self.error_file_data = _EMPTY_ERROR_DATA + self.message = "" + self.timestamp = int(time.time()) + + def signal_name(self) -> str: + if self.exitcode < 0: + # We don't want to kill the parent process trying to find the signal name. + # if the signal doesn't map to a known name, use not available. + try: + return signal.Signals(-self.exitcode).name + except Exception: + return _NOT_AVAILABLE + else: + return _NOT_AVAILABLE + + def timestamp_isoformat(self): + """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" + return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") + + +GlobalRank = int + +_FAILURE_FORMAT_TEMPLATE = """[${idx}]: + time : ${time} + host : ${hostname} + rank : ${rank} (local_rank: ${local_rank}) + exitcode : ${exitcode} (pid: ${pid}) + error_file: ${error_file} + traceback : ${message}""" + +# extra new lines before and after are intentional +_MSG_FORMAT_TEMPLATE = """ +${boarder} +${title} +${section} +Failures: +${other_failures} +${section} +Root Cause (first observed failure): +${root_failure} +${boarder}""" + + +class ChildFailedError(Exception): + """ + Special exception type that can be raised from a function annotated with the + ``@record`` decorator to have the child process' (root exception) propagate + up the stack as-is (e.g. without being wrapped in the parent's traceback). + + Useful in cases where the parent is a simple nanny process + and the child (worker) processes are actually doing meaningful compute. + In this case, errors typically occur on the child process as the parent + is not doing anything non-trivial, and child errors should be propagated + to the scheduler for accurate root cause diagnostics. + + .. note:: The propagation relies on error files rather than exception handling to + support both function and binary launches. + + Example: + :: + + # process tree on a host (container) + 0: scheduler-init-process: + |- 1: torchelastic_agent: + |- 2: trainer_0 (ok) + |- 3: trainer_1 (fail) -> error.json + |- ... + |- n+2: trainer_n (ok) + |- n+3: other processes + |- ... + + In the example above, trainer 1's failure (written into error.json) is + the root cause and should be reported to the scheduler's init process. + The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` + upon detecting trainer 1's failure which would propagate the contents + of trainer 1's error file to the scheduler's init process. + """ + + def __init__(self, name: str, failures: dict[GlobalRank, ProcessFailure]): + self.name = name + self.failures = failures + assert ( + self.failures + ) # does not make sense to create a ChildFaileError with no failures + super().__init__(self.format_msg()) + + def get_first_failure(self) -> tuple[GlobalRank, ProcessFailure]: + rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) + return rank, self.failures[rank] + + def format_msg(self, boarder_delim="=", section_delim="-"): + title = f"{self.name} FAILED" + root_rank, _root_failure = self.get_first_failure() + + root_failure_fmt: str = "" + other_failures_fmt: list[str] = [] + width = len(title) + for idx, (rank, failure) in enumerate(self.failures.items()): + fmt, w = self._format_failure(idx, rank, failure) + width = max(width, w) + if rank == root_rank: + root_failure_fmt = fmt + else: + other_failures_fmt.append(fmt) + + # upper boundary on width + width = min(width, 60) + + return Template(_MSG_FORMAT_TEMPLATE).substitute( + boarder=boarder_delim * width, + title=title, + section=section_delim * width, + root_failure=root_failure_fmt, + other_failures="\n".join(other_failures_fmt or [" "]), + ) + + def _format_failure( + self, idx: int, rank: int, failure: ProcessFailure + ) -> tuple[str, int]: + # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) + # or a dict (json) of the form + # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} + # so the display logic is: + # 1. if failure.message is not a dict (it is a str) just show it as is + # 2. else try to get the traceback (py_callstack) + # 3. if the traceback is not there, use the message + # 4. if the message is not there show + msg = failure.message + if isinstance(failure.message, dict): + msg = ( + failure.message.get("extraInfo", {}) + .get("py_callstack", failure.message.get("message", "")) + .replace("\n", "\n ") # to properly indent the traceback + ) + + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( + idx=idx, + time=failure.timestamp_isoformat(), + hostname=socket.getfqdn(), + rank=rank, + local_rank=failure.local_rank, + exitcode=failure.exitcode, + pid=failure.pid, + error_file=failure.error_file, + message=msg, + ) + width = 0 + for line in fmt.split("\n"): + width = max(width, len(line)) + return fmt, width + + +def record( + fn: Callable[_P, _R], error_handler: ErrorHandler | None = None +) -> Callable[_P, _R | None]: + """ + Syntactic sugar to record errors/exceptions that happened in the decorated + function using the provided ``error_handler``. + + Using this decorator is equivalent to: + + :: + + error_handler = get_error_handler() + error_handler.initialize() + try: + foobar() + except ChildFailedError as e: + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + .. important:: use this decorator once per process at the top level method, + typically this is the main method. + + Example + + :: + + @record + def main(): + pass + + + if __name__ == "__main__": + main() + + """ + if not error_handler: + error_handler = get_error_handler() + + def wrap(f: Callable[_P, _R]) -> Callable[_P, _R | None]: + @wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + assert error_handler is not None # assertion for mypy type checker + error_handler.initialize() + try: + return f(*args, **kwargs) + except SystemExit as se: + # For run_path based entrypoints, SystemExit with code = 0 will never exit. + # Handling it here by returning a value: + if se.code == 0: + return None + else: + raise + except ChildFailedError as e: + rank, failure = e.get_first_failure() + if failure.error_file != _NOT_AVAILABLE: + error_handler.dump_error_file(failure.error_file, failure.exitcode) + else: + logger.info( + ( + "local_rank %s FAILED with no error file." + " Decorate your entrypoint fn with @record for traceback info." + " See: https://pytorch.org/docs/stable/elastic/errors.html", + rank, + ) + ) + raise + except Exception as e: + error_handler.record_exception(e) + raise + + return wrapper + + return wrap(fn) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..891e7e9ba115b34ed43b708121633a43bdd9980a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daf29e0c95259aa1e1292f371979c2dd74c18a58 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..497726974806c4b252c10e54445633c73b4a362f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6613e54dee10edbec54abcc5bc689b01676358 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import faulthandler +import json +import logging +import os +import time +import traceback +import warnings +from typing import Any + + +__all__ = ["ErrorHandler"] + +logger = logging.getLogger(__name__) + + +class ErrorHandler: + """ + Write the provided exception object along with some other metadata about + the error in a structured way in JSON format to an error file specified by the + environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment + variable is not set, then simply logs the contents of what would have been + written to the error file. + + This handler may be subclassed to customize the handling of the error. + Subclasses should override ``initialize()`` and ``record_exception()``. + """ + + def _get_error_file_path(self) -> str | None: + """ + Return the error file path. + + May return ``None`` to have the structured error be logged only. + """ + return os.environ.get("TORCHELASTIC_ERROR_FILE", None) + + def initialize(self) -> None: + """ + Call prior to running code that we wish to capture errors/exceptions. + + Typically registers signal/fault handlers. Users can override this + function to add custom initialization/registrations that aid in + propagation/information of errors/signals/exceptions/faults. + """ + try: + faulthandler.enable(all_threads=True) + except Exception as e: + warnings.warn( + f"Unable to enable fault handler. {type(e).__name__}: {e}", stacklevel=2 + ) + + def _write_error_file(self, file_path: str, error_msg: str) -> None: + """Write error message to the file.""" + try: + with open(file_path, "w") as fp: + fp.write(error_msg) + except Exception as e: + warnings.warn( + f"Unable to write error to file. {type(e).__name__}: {e}", stacklevel=2 + ) + + def record_exception(self, e: BaseException) -> None: + """ + Write a structured information about the exception into an error file in JSON format. + + If the error file cannot be determined, then logs the content + that would have been written to the error file. + """ + file = self._get_error_file_path() + if file: + data = { + "message": { + "message": f"{type(e).__name__}: {e}", + "extraInfo": { + "py_callstack": traceback.format_exc(), + "timestamp": str(int(time.time())), + }, + } + } + with open(file, "w") as fp: + json.dump(data, fp) + + def override_error_code_in_rootcause_data( + self, + rootcause_error_file: str, + rootcause_error: dict[str, Any], + error_code: int = 0, + ): + """Modify the rootcause_error read from the file, to correctly set the exit code.""" + if "message" not in rootcause_error: + logger.warning( + "child error file (%s) does not have field `message`. \n" + "cannot override error code: %s", + rootcause_error_file, + error_code, + ) + elif isinstance(rootcause_error["message"], str): + logger.warning( + "child error file (%s) has a new message format. \n" + "skipping error code override", + rootcause_error_file, + ) + else: + rootcause_error["message"]["errorCode"] = error_code + + def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): + """Dump parent error file from child process's root cause error and error code.""" + with open(rootcause_error_file) as fp: + rootcause_error = json.load(fp) + # Override error code since the child process cannot capture the error code if it + # is terminated by signals like SIGSEGV. + if error_code: + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) + logger.debug( + "child error file (%s) contents:\n%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), + ) + + my_error_file = self._get_error_file_path() + if my_error_file: + # Guard against existing error files + # This can happen when the child is created using multiprocessing + # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the + # parent and child to specify the error files (respectively) + # because the env vars on the child is set in the wrapper function + # and by default the child inherits the parent's env vars, if the child + # process receives a signal before the wrapper function kicks in + # and the signal handler writes to the error file, then the child + # will write to the parent's error file. In this case just log the + # original error file contents and overwrite the error file. + self._rm(my_error_file) + self._write_error_file(my_error_file, json.dumps(rootcause_error)) + logger.info("dumped error file to parent's %s", my_error_file) + else: + logger.error( + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, + ) + + def _rm(self, my_error_file): + if os.path.isfile(my_error_file): + # Log the contents of the original file. + with open(my_error_file) as fp: + try: + original = json.dumps(json.load(fp), indent=2) + logger.warning( + "%s already exists" + " and will be overwritten." + " Original contents:\n%s", + my_error_file, + original, + ) + except json.decoder.JSONDecodeError: + logger.warning( + "%s already exists" + " and will be overwritten." + " Unable to load original contents:\n", + my_error_file, + ) + os.remove(my_error_file) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..6721217a41190c2bdd6bf2293540a33c893c145d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# Multiprocessing error-reporting module + + +from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler + + +__all__ = ["get_error_handler"] + + +def get_error_handler() -> ErrorHandler: + return ErrorHandler() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py new file mode 100644 index 0000000000000000000000000000000000000000..057013fbb9e5b8a2aeca69b41d7679cbe75c0e28 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/redirects.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +# !/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Taken and modified from original source: +# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +import ctypes +import logging +import os +import sys +from contextlib import contextmanager +from functools import partial + + +IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" + + +logger = logging.getLogger(__name__) + + +def get_libc(): + if IS_WINDOWS or IS_MACOS: + logger.warning( + "NOTE: Redirects are currently not supported in Windows or MacOs." + ) + return None + else: + return ctypes.CDLL("libc.so.6") + + +libc = get_libc() + + +def _c_std(stream: str): + return ctypes.c_void_p.in_dll(libc, stream) + + +def _python_std(stream: str): + return {"stdout": sys.stdout, "stderr": sys.stderr}[stream] + + +_VALID_STD = {"stdout", "stderr"} + + +@contextmanager +def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. + + This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). + See usage for details. + + Directory of ``dst_filename`` is assumed to exist and the destination file + is overwritten if it already exists. + + .. note:: Due to buffering cross source writes are not guaranteed to + appear in wall-clock order. For instance in the example below + it is possible for the C-outputs to appear before the python + outputs in the log file. + + Usage: + + :: + + # syntactic-sugar for redirect("stdout", "tmp/stdout.log") + with redirect_stdout("/tmp/stdout.log"): + print("python stdouts are redirected") + libc = ctypes.CDLL("libc.so.6") + libc.printf(b"c stdouts are also redirected" + os.system("echo system stdouts are also redirected") + + print("stdout restored") + + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + c_std = _c_std(std) + python_std = _python_std(std) + std_fd = python_std.fileno() + + def _redirect(dst): + libc.fflush(c_std) + python_std.flush() + os.dup2(dst.fileno(), std_fd) + + with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: + _redirect(dst) + try: + yield + finally: + _redirect(orig_std) + + +redirect_stdout = partial(redirect, "stdout") +redirect_stderr = partial(redirect, "stderr") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f56d423ce080fd7c331dc9b43eda58e5370678fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import ( + get_subprocess_handler, +) +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) + + +__all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8ee7f7ef7c7fa9cd7192717707fa74920075e24 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c2ea93c96a259e53a67018e2e71959e0d1c57a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c60524ae7c79440fb74038c064a62ef2d469bd54 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1742626e285838485c19911704792510d13fb4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( + SubprocessHandler, +) +from torch.numa.binding import NumaOptions + + +__all__ = ["get_subprocess_handler"] + + +def get_subprocess_handler( + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: str, + stderr: str, + local_rank_id: int, + numa_options: NumaOptions | None = None, +) -> SubprocessHandler: + return SubprocessHandler( + entrypoint=entrypoint, + args=args, + env=env, + stdout=stdout, + stderr=stderr, + local_rank_id=local_rank_id, + numa_options=numa_options, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..268817108d8cd20f6ba0130818286d297da78c4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import signal +import sys +from subprocess import Popen +from typing import Any + +from torch.numa.binding import maybe_wrap_command_args_with_numa_binding, NumaOptions + + +__all__ = ["SubprocessHandler"] + +IS_WINDOWS = sys.platform == "win32" + + +def _get_default_signal() -> signal.Signals: + """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" + if IS_WINDOWS: + return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + else: + return signal.SIGTERM + + +class SubprocessHandler: + """ + Convenience wrapper around python's ``subprocess.Popen``. Keeps track of + meta-objects associated to the process (e.g. stdout and stderr redirect fds). + """ + + def __init__( + self, + entrypoint: str, + args: tuple, + env: dict[str, str], + stdout: str | None, + stderr: str | None, + local_rank_id: int, + numa_options: NumaOptions | None, + ): + self._stdout = open(stdout, "w") if stdout else None # noqa: SIM115 + self._stderr = open(stderr, "w") if stderr else None # noqa: SIM115 + # inherit parent environment vars + env_vars = os.environ.copy() + env_vars.update(env) + + args_str = (entrypoint, *[str(e) for e in args]) + args_str = maybe_wrap_command_args_with_numa_binding( + args_str, + gpu_index=local_rank_id, + numa_options=numa_options, + ) + + self.local_rank_id = local_rank_id + + self.proc: Popen = self._popen(args_str, env_vars) + + def _popen(self, args: tuple, env: dict[str, str]) -> Popen: + kwargs: dict[str, Any] = {} + if not IS_WINDOWS: + kwargs["start_new_session"] = True + + return Popen( + # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], + # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got + # `Tuple[str, *Tuple[Any, ...]]`. + args=args, + env=env, + stdout=self._stdout, + stderr=self._stderr, + **kwargs, + ) + + def close(self, death_sig: signal.Signals | None = None) -> None: + if not death_sig: + death_sig = _get_default_signal() + if IS_WINDOWS: + self.proc.send_signal(death_sig) + else: + os.killpg(self.proc.pid, death_sig) + if self._stdout: + self._stdout.close() + if self._stderr: + self._stderr.close() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py new file mode 100644 index 0000000000000000000000000000000000000000..77d410cce55c09b0acd79ebf4583028f5a7bb759 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import time +from collections.abc import Callable +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Event +from typing import TextIO, TYPE_CHECKING + + +if TYPE_CHECKING: + from concurrent.futures._base import Future + +__all__ = ["tail_logfile", "TailLog"] + +logger = logging.getLogger(__name__) + + +def tail_logfile( + header: str, + file: str, + dst: TextIO, + finished: Event, + interval_sec: float, + log_line_filter: Callable[[str], bool] | None = None, +): + while not os.path.exists(file): + if finished.is_set(): + return + time.sleep(interval_sec) + + with open(file, errors="replace") as fp: + while True: + line = fp.readline() + + if line: + if log_line_filter and log_line_filter(line): + dst.write(f"{header}{line}") + else: # reached EOF + if finished.is_set(): + # log line producer is finished + break + else: + # log line producer is still going + # wait for a bit before looping again + time.sleep(interval_sec) + + +class TailLog: + """ + Tail the given log files. + + The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until + the log files are created by the producer and will tail the contents of the + log files until the ``stop()`` method is called. + + .. warning:: ``TailLog`` will wait indefinitely for the log file to be created! + + Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, + where the ``name`` is user-provided and ``idx`` is the index of the log file + in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the + header for each log file. + + Usage: + + :: + + log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"} + tailer = TailLog("trainer", log_files, sys.stdout).start() + # actually run the trainers to produce 0_stdout.log and 1_stdout.log + run_trainers() + tailer.stop() + + # once run_trainers() start writing the ##_stdout.log files + # the tailer will print to sys.stdout: + # >>> [trainer0]:log_line1 + # >>> [trainer1]:log_line1 + # >>> [trainer0]:log_line2 + # >>> [trainer0]:log_line3 + # >>> [trainer1]:log_line2 + + .. note:: Due to buffering log lines between files may not necessarily + be printed out in order. You should configure your application's + logger to suffix each log line with a proper timestamp. + + """ + + def __init__( + self, + name: str, + log_files: dict[int, str], + dst: TextIO, + log_line_prefixes: dict[int, str] | None = None, + interval_sec: float = 0.1, + log_line_filter: Callable[[str], bool] = (lambda _: True), + ): + n = len(log_files) + self._threadpool = None + if n > 0: + # pyrefly: ignore [bad-assignment] + self._threadpool = ThreadPoolExecutor( + max_workers=n, + thread_name_prefix=f"{self.__class__.__qualname__}_{name}", + ) + + self._name = name + self._dst = dst + self._log_files = log_files + self._log_line_prefixes = log_line_prefixes + self._log_line_filter = log_line_filter + self._finished_events: dict[int, Event] = { + local_rank: Event() for local_rank in log_files + } + self._futs: list[Future] = [] + self._interval_sec = interval_sec + self._stopped = False + + def start(self) -> "TailLog": + if not self._threadpool or not self._dst: + return self + + for local_rank, file in self._log_files.items(): + header = f"[{self._name}{local_rank}]:" + if self._log_line_prefixes and local_rank in self._log_line_prefixes: + header = self._log_line_prefixes[local_rank] + self._futs.append( + self._threadpool.submit( + tail_logfile, + header=header, + file=file, + dst=self._dst, + finished=self._finished_events[local_rank], + interval_sec=self._interval_sec, + log_line_filter=self._log_line_filter, + ) + ) + return self + + def stop(self) -> None: + for finished in self._finished_events.values(): + finished.set() + + for local_rank, f in enumerate(self._futs): + try: + f.result() + except Exception as e: + logger.exception( + "error in log tailor for %s%s. %s", + self._name, + local_rank, + e.__class__.__qualname__, + ) + + if self._threadpool: + self._threadpool.shutdown(wait=True) + + self._stopped = True + + def stopped(self) -> bool: + return self._stopped diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c387a3ec2833ac643c571afa7a194a1dc0d3fbea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__init__.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +In the context of Torch Distributed Elastic we use the term *rendezvous* to +refer to a particular functionality that combines a **distributed +synchronization** primitive with **peer discovery**. + +It is used by Torch Distributed Elastic to gather participants of a training +job (i.e. nodes) such that they all agree on the same list of participants and +everyone's roles, as well as make a consistent collective decision on when +training can begin/resume. + +Torch Distributed Elastic rendezvous provides the following critical +functionalities: + +**Barrier**: + +Nodes performing rendezvous will all block until the rendezvous is considered +complete - this happens when at least ``min`` total number of nodes have joined +the rendezvous barrier (for the same job). This also implies the barrier is not +necessarily of fixed size. + +There's an additional small waiting time after reaching ``min`` number of +nodes - this is used to ensure the rendezvous is not completed "too quickly" +(which could potentially exclude additional nodes attempting to join at +approximately the same time). + +If ``max`` number of nodes is gathered at the barrier, the rendezvous is +completed immediately. + +There's also an overall timeout which causes the rendezvous to fail if ``min`` +number of nodes is never reached - this is meant to be a simple fail-safe to +help release partially allocated job resources, in case there's a problem with +the resource manager, and is meant to be interpreted as non-retryable. + +**Exclusivity**: + +A simple distributed barrier would not be sufficient, as we also need to ensure +that only one group of nodes exists at any given time (for a given job). In +other words, new nodes (i.e. joining late) should not be able to form a parallel +independent group of workers for the same job. + +Torch Distributed Elastic rendezvous ensures that if a group of nodes has +already completed a rendezvous (and hence might already be training), then +additional "late" nodes attempting to rendezvous will only announce themselves +as waiting, and will have to wait until the (previously completed) existing +rendezvous is destroyed first. + +**Consistency**: + +When a rendezvous is completed, all its members will agree on the job membership +and everyone's role in it. This role is represented using an integer, called +rank, that is between between 0 and world size. + +Note that ranks are *not stable*, in the sense that the same node can be +assigned a different rank in the next (re-)rendezvous. + +**Fault-tolerance**: + +Torch Distributed Elastic rendezvous is designed to tolerate node failures +during the rendezvous process. Should a process crash (or lose network +connectivity, etc), between joining the rendezvous and it being completed, then +a re-rendezvous with remaining healthy nodes will happen automatically. + +A node can also fail *after* it has completed (or *has been observed* by other +nodes to have completed) the rendezvous - this scenario will be handled by the +Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a +re-rendezvous). + +**Shared key-value store**: + +When the rendezvous is completed, a shared key-value store is created and +returned. This store implements a ``torch.distributed.Store`` API (see +`distributed communication docs +`__). + +This store is only shared by the members of the completed rendezvous. It +is intended to be used by Torch Distributed Elastic to exchange information +necessary to initialize job control and data-planes. + +**Waiting workers and rendezvous closing**: + +Torch Distributed Elastic rendezvous handler object provides additional +functionalities, which are technically not part of the rendezvous process: + +1. Querying how many workers arrived late at the barrier, who can participate in + *next* rendezvous. + +2. Setting the rendezvous *closed* to signal all nodes not to participate in + next rendezvous. + +**DynamicRendezvousHandler**: + +Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler` +class that implements the rendezvous mechanism described above. It is a backend- +agnostic type that expects a particular :py:class:`.RendezvousBackend` instance +to be specified during construction. + +Torch distributed users can either implement their own backend type or use one +of the following implementations that come with PyTorch: + +- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default + ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d + store is that it requires no 3rd-party dependency (such as etcd) to establish + a rendezvous. +- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy + :py:class:`.EtcdRendezvousHandler` class. Passing an + :py:class:`.EtcdRendezvousBackend` instance to + :py:class:`.DynamicRendezvousHandler` is functionally equivalent to + instantiating an :py:class:`.EtcdRendezvousHandler`. + + :: + + store = TCPStore("localhost") + + backend = C10dRendezvousBackend(store, "my_run_id") + + rdzv_handler = DynamicRendezvousHandler.from_backend( + run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4 + ) +""" + +from .api import ( + rendezvous_handler_registry, + RendezvousClosedError, + RendezvousConnectionError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousHandlerCreator, + RendezvousHandlerRegistry, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .registry import _register_default_handlers, _register_out_of_tree_handlers + + +_register_default_handlers() +_register_out_of_tree_handlers() + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83b7cac4611572c82ef1d91960e99351cd1cd368 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d864bc2248e7418ab33023d9640db04d367bdd3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f856b44ca83f93122ef5aa62d93814249d24085 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1578575555431c78b9e6d35153dfb435e16b0bef Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..591cd09199c1aefcf2b8c76552c0d739f618e012 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d63ce2a418cff27bd530109b49e9ae237ebc7364 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48453291c91b2f296dd1f42e1bd6dc9cfd162ac4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f158262216316b5707b7ba72614a09f93aa1bb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df9387b0fab7924c65c36e459066644cc2c8c6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a115e667f582cf2c4933961c28e6158dc93ebeb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02731883f783a1dbdb780e065f875e11aaeaa1a2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a5277ba8d8c0c1c3c6f3d46b1383d65ecd98ba2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py new file mode 100644 index 0000000000000000000000000000000000000000..5890a97c672a61b5678e66b006ba173fe7668286 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/_etcd_stub.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + + +""" +This file is not meant to be used directly. It serves as a stub to allow +other files to be safely imported without requiring the installation of +the 'etcd' library. The classes and methods here raise exceptions to +indicate that the real 'etcd' module is needed. +""" + + +class EtcdStubError(ImportError): + """Custom exception to indicate that the real etcd module is required.""" + + def __init__(self) -> None: + super().__init__("The 'etcd' module is required but not installed.") + + +class EtcdAlreadyExist(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdCompareFailed(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdKeyNotFound(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdWatchTimedOut(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdEventIndexCleared(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdException(Exception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + +class EtcdResult: + def __init__(self) -> None: + raise EtcdStubError + + +class Client: + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise EtcdStubError + + def read(self, key: str) -> None: + raise EtcdStubError + + def write( + self, key: str, value: Any, ttl: int | None = None, **kwargs: Any + ) -> None: + raise EtcdStubError + + def test_and_set( + self, key: str, value: Any, prev_value: Any, ttl: int | None = None + ) -> None: + raise EtcdStubError diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3fa8183dfb81da2f0b675a5e1a5d1f6fee935f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/api.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import socket +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, ClassVar + +from torch.distributed import Store +from torch.distributed.elastic.utils.distributed import get_free_port + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] + + +class RendezvousError(Exception): + """Represents the base type for rendezvous errors.""" + + +class RendezvousClosedError(RendezvousError): + """Raised when a rendezvous is closed.""" + + +class RendezvousTimeoutError(RendezvousError): + """Raised when a rendezvous did not complete on time.""" + + +class RendezvousConnectionError(RendezvousError): + """Raised when the connection to a rendezvous backend has failed.""" + + +class RendezvousStateError(RendezvousError): + """Raised when the state of a rendezvous is corrupt.""" + + +class RendezvousGracefulExitError(RendezvousError): + """Raised when node wasn't not included in rendezvous and gracefully exits. + + Exception is a mechanism to exit the stack, however does not mean a failure. + """ + + +@dataclass +class RendezvousStoreInfo: + """Store address and port that can be used to bootstrap trainer distributed comms""" + + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" + MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" + master_addr: str + master_port: int + + @staticmethod + def build( + rank: int, + store: Store, + local_addr: str | None, + server_port: int | None = None, + ) -> "RendezvousStoreInfo": + """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. + + If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname + server_port: port of the TCPStore server, when the TCPStore is shared. + """ + # TODO swap to collectives comms API + if rank == 0: + addr = local_addr or socket.getfqdn() + # When TCPStore is not shared, we fallback to get_free_port. + port = server_port or get_free_port() + store.set( + RendezvousStoreInfo.MASTER_ADDR_KEY, + addr.encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + store.set( + RendezvousStoreInfo.MASTER_PORT_KEY, + str(port).encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + + addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) + return RendezvousStoreInfo(master_addr=addr, master_port=port) + + +class RendezvousInfo: + """Holds the information about the rendezvous.""" + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): + self._store = store + self._rank = rank + self._world_size = world_size + self._bootstrap_store_info = bootstrap_store_info + + @property + def store(self) -> Store: + """Store used by torchelastic control plane""" + return self._store + + @property + def rank(self) -> int: + """Rank within a group""" + return self._rank + + @property + def world_size(self) -> int: + """Global group size""" + return self._world_size + + @property + def bootstrap_store_info(self) -> RendezvousStoreInfo | None: + """Store information that can used by trainer code to bootstrap distributed comms.""" + return self._bootstrap_store_info + + +class RendezvousHandler(ABC): + """Main rendezvous interface. + + Note: + Distributed Torch users normally **do not** need to implement their own + ``RendezvousHandler``. An implementation based on C10d Store is already + provided, and is recommended for most users. + """ + + @abstractmethod + def get_backend(self) -> str: + """Return the name of the rendezvous backend.""" + + @property + def use_agent_store(self) -> bool: + """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user + applications and will be available during application lifecycle. + + Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. + Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. + """ + return False + + @abstractmethod + def next_rendezvous(self) -> RendezvousInfo: + """Main entry-point into the rendezvous barrier. + + Blocks until the rendezvous is complete and the current process is + included in the formed worker group, or a timeout occurs, or the + rendezvous was marked closed. + + Returns: + Instance of :py:class:`RendezvousInfo`. + + Raises: + RendezvousClosedError: + The rendezvous is closed. + RendezvousConnectionError: + The connection to the rendezvous backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + RendezvousTimeoutError: + The rendezvous did not complete on time. + """ + + @abstractmethod + def is_closed(self) -> bool: + """Check whether the rendezvous has been closed. + + A closed rendezvous means all future attempts to re-rendezvous within + same job will fail. + + ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual + propagation and should not be used for synchronization. The intention is + that if at least one node decides the job is finished, it will close the + rendezvous, and other nodes will soon observe this and stop running as + well. + """ + + @abstractmethod + def set_closed(self): + """Mark the rendezvous as closed.""" + + @abstractmethod + def num_nodes_waiting(self) -> int: + """Return the number of nodes who arrived late at the rendezvous + barrier, hence were not included in the current worker group. + + Callers should periodically call this method to check whether new + nodes are waiting to join the job and if so admit them by calling + :py:meth:`next_rendezvous()` (re-rendezvous). + """ + + @abstractmethod + def get_run_id(self) -> str: + """Return the run id of the rendezvous. + + The run id is a user-defined id that uniquely identifies an instance of + a distributed application. It typically maps to a job id and is used to + allow nodes to join the correct distributed application. + """ + + @abstractmethod + def shutdown(self) -> bool: + """Close all resources that were open for the rendezvous. + + Example:: + + rdzv_handler = ... + try: + store, rank, world_size = rdzv_handler.next_rendezvous() + finally: + rdzv_handler.shutdown() + """ + + +class RendezvousParameters: + """Hold the parameters to construct a :py:class:`RendezvousHandler`. + + Args: + backend: + The name of the backend to use to handle the rendezvous. + endpoint: + The endpoint of the rendezvous, usually in form [:]. + run_id: + The id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The address of the local node. + **kwargs: + Additional parameters for the specified backend. + """ + + def __init__( + self, + backend: str, + endpoint: str, + run_id: str, + min_nodes: int, + max_nodes: int, + local_addr: str | None = None, + **kwargs, + ): + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + if min_nodes < 1: + raise ValueError( + f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero." + ) + if max_nodes < min_nodes: + raise ValueError( + f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or " + f"equal to the minimum number of rendezvous nodes ({min_nodes})." + ) + + self.backend = backend + self.endpoint = endpoint + self.run_id = run_id + self.min_nodes = min_nodes + self.max_nodes = max_nodes + self.config = kwargs + self.local_addr = local_addr + + def get(self, key: str, default: Any = None) -> Any: + """Return the value for ``key`` if ``key`` exists, else ``default``.""" + return self.config.get(key, default) + + def get_as_bool(self, key: str, default: bool | None = None) -> bool | None: + """Return the value for ``key`` as a ``bool``.""" + value = self.get(key, default) + if value is None or isinstance(value, bool): + return value + if isinstance(value, int): + if value == 1: + return True + if value == 0: + return False + elif isinstance(value, str): + if value.lower() in ["1", "true", "t", "yes", "y"]: + return True + if value.lower() in ["0", "false", "f", "no", "n"]: + return False + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid boolean value." + ) + + def get_as_int(self, key: str, default: int | None = None) -> int | None: + """Return the value for ``key`` as an ``int``.""" + value = self.get(key, default) + if value is None: + return value + try: + return int(value) + except ValueError as e: + raise ValueError( + f"The rendezvous configuration option '{key}' does not represent a valid integer " + "value." + ) from e + + +RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler] + + +class RendezvousHandlerRegistry: + """Represent a registry of :py:class:`RendezvousHandler` backends.""" + + _registry: dict[str, RendezvousHandlerCreator] + + def __init__(self) -> None: + self._registry = {} + + def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: + """Register a new rendezvous backend. + + Args: + backend: + The name of the backend. + creator: + The callback to invoke to construct the + :py:class:`RendezvousHandler`. + """ + if not backend: + raise ValueError("The rendezvous backend name must be a non-empty string.") + + current_creator: RendezvousHandlerCreator | None + try: + current_creator = self._registry[backend] + except KeyError: + current_creator = None + + if current_creator is not None and current_creator != creator: + raise ValueError( + f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it " + f"is already registered with '{current_creator}'." + ) + + self._registry[backend] = creator + + def create_handler(self, params: RendezvousParameters) -> RendezvousHandler: + """Create a new :py:class:`RendezvousHandler`.""" + try: + creator = self._registry[params.backend] + except KeyError as e: + raise ValueError( + f"The rendezvous backend '{params.backend}' is not registered. Did you forget " + f"to call `{self.register.__name__}`?" + ) from e + + handler = creator(params) + + # Do some sanity check. + if handler.get_backend() != params.backend: + raise RuntimeError( + f"The rendezvous backend '{handler.get_backend()}' does not match the requested " + f"backend '{params.backend}'." + ) + + return handler + + +# The default global registry instance used by launcher scripts to instantiate +# rendezvous handlers. +rendezvous_handler_registry = RendezvousHandlerRegistry() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..0296c4d45ddc13dadc9ee1d91f07a3950c277892 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -0,0 +1,270 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +import logging +import os +import tempfile +from base64 import b64decode, b64encode +from datetime import timedelta +from typing import Any, cast + +from torch.distributed import FileStore, Store, TCPStore +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousConnectionError, + RendezvousError, + RendezvousParameters, + RendezvousStateError, +) +from .dynamic_rendezvous import RendezvousBackend, Token +from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + + +logger = logging.getLogger(__name__) + +# default port for the TCP store +DEFAULT_PORT = 29400 + + +class C10dRendezvousBackend(RendezvousBackend): + """Represents a C10d-backed rendezvous backend. + + Args: + store: + The :py:class:`torch.distributed.Store` instance to use to + communicate with the C10d store. + run_id: + The run id of the rendezvous. + """ + + # See the explanation in the __init__ method. + _NULL_SENTINEL = "Y2FuaW1hZGFt" + + _store: Store + _key: str + + def __init__(self, store: Store, run_id: str) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._store = store + + self._key = "torch.rendezvous." + run_id + + # The read operation of a store blocks the caller until the specified + # key becomes available. This behavior makes it tricky to use a store + # as a regular key-value dictionary. + # + # As a workaround we initially set a sentinel value as the rendezvous + # state. Whenever this value gets returned we treat it as a None. + self._call_store("compare_set", self._key, "", self._NULL_SENTINEL) + + @property + def name(self) -> str: + """See base class.""" + return "c10d" + + def get_state(self) -> tuple[bytes, Token] | None: + """See base class.""" + base64_state: bytes = self._call_store("get", self._key) + + return self._decode_state(base64_state) + + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """See base class.""" + base64_state_str: str = b64encode(state).decode() + + if token: + # Shortcut if we know for sure that the token is not valid. + if not isinstance(token, bytes): + result = self.get_state() + if result is not None: + return *result, False + return None + + token = token.decode() + else: + token = self._NULL_SENTINEL + + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) + + state_token_pair = self._decode_state(base64_state) + if state_token_pair is None: + return None + + new_state, new_token = state_token_pair + + # C10d Store's compare_set method does not offer an easy way to find out + # whether our write attempt was successful. As a brute-force solution we + # perform a bitwise comparison of our local state and the remote state. + return new_state, new_token, new_state == state + + def _call_store(self, store_op: str, *args, **kwargs) -> Any: + try: + return getattr(self._store, store_op)(*args, **kwargs) + except (ValueError, RuntimeError, TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + def _decode_state(self, base64_state: bytes) -> tuple[bytes, Token] | None: + if base64_state == self._NULL_SENTINEL.encode(): + return None + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + return state, base64_state + + +def _create_tcp_store(params: RendezvousParameters) -> TCPStore: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) + + cfg_is_host = params.get_as_bool("is_host") + # If the user has explicitly specified whether our process should host the + # the store, respect it. + if cfg_is_host is not None: + is_host = cfg_is_host + # Otherwise try to determine whether we are the host based on our hostname + # and IP address. + else: + is_host = _matches_machine_hostname(host) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # In specific cases we attempt to instantiate the store twice. For details + # see the explanation in the except clause below. + for is_server in [is_host, False]: + try: + store = TCPStore( + host, + port, + is_master=is_server, + multi_tenant=True, + timeout=timedelta(seconds=read_timeout), + ) + + if is_server: + msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend." + construct_and_record_rdzv_event( + run_id=params.run_id, message=msg, node_state=NodeState.INIT + ) + logger.info(msg) + + break + except (ValueError, RuntimeError, TimeoutError) as exc: + # If we heuristically inferred the value of is_host as True and our + # first attempt to instantiate the TCP store has failed, try it one + # more time with is_host set to False. As an edge case there can be + # more than one process that is part of the same rendezvous on this + # machine and only one of them will eventually host the store. + + if not is_server or cfg_is_host is not None: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store # type: ignore[possibly-undefined] + + +def _create_file_store(params: RendezvousParameters) -> FileStore: + # If a user specifies an endpoint, we treat it as a path to a file. + if params.endpoint: + path = params.endpoint + else: + try: + # The temporary file is readable and writable only by the user of + # this process. + _, path = tempfile.mkstemp() + except OSError as exc: + raise RendezvousError( + "The file creation for C10d store has failed. See inner exception for details." + ) from exc + + try: + store = FileStore(path) + except (ValueError, RuntimeError) as exc: + raise RendezvousConnectionError( + "The connection to the C10d store has failed. See inner exception for details." + ) from exc + + return store + + +def create_backend(params: RendezvousParameters) -> tuple[C10dRendezvousBackend, Store]: + """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | store_type | The type of the C10d store. The currently supported types | + | | are "tcp" and "file" which correspond to | + | | :py:class:`torch.distributed.TCPStore` and | + | | :py:class:`torch.distributed.FileStore`, respectively. | + | | Defaults to "tcp". | + +--------------+-----------------------------------------------------------+ + | read_timeout | The read timeout, in seconds, for store operations. | + | | Defaults to 60 seconds. | + | | | + | | Note this only applies to | + | | :py:class:`torch.distributed.TCPStore`. It is not relevant| + | | to :py:class:`torch.distributed.FileStore` which does not | + | | take in timeout as a parameter. | + +--------------+-----------------------------------------------------------+ + | is_host | A boolean value indicating whether this backend instance | + | | will host the C10d store. If not specified it will be | + | | inferred heuristically by matching the hostname or the IP | + | | address of this machine against the specified rendezvous | + | | endpoint. Defaults to ``None``. | + | | | + | | Note that this configuration option only applies to | + | | :py:class:`torch.distributed.TCPStore`. In normal | + | | circumstances you can safely skip it; the only time when | + | | it is needed is if its value cannot be correctly | + | | determined (e.g. the rendezvous endpoint has a CNAME as | + | | the hostname or does not match the FQDN of the machine). | + +--------------+-----------------------------------------------------------+ + """ + # As of today we only support TCPStore and FileStore. Other store types do + # not have the required functionality (e.g. compare_set) yet. + store_type = params.get("store_type", "tcp").strip().lower() + store: Store + + try: + if store_type == "file": + store = _create_file_store(params) + elif store_type == "tcp": + store = _create_tcp_store(params) + else: + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) + + backend = C10dRendezvousBackend(store, params.run_id) + + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise + + return backend, store diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..84adeea95573121e69f11c6faa52fe6601f271c7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -0,0 +1,1453 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import pickle +import socket +import threading +import time +import weakref +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any + +import torch.distributed as dist +from torch.distributed import Store +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState + +from .api import ( + RendezvousClosedError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .utils import _delay, _PeriodicTimer + + +__all__ = [ + "RendezvousBackend", + "RendezvousTimeout", + "RendezvousSettings", + "DynamicRendezvousHandler", + "create_handler", +] + +logger = logging.getLogger(__name__) + + +def get_method_name(depth=2): + if len(inspect.stack()) > depth: + return inspect.stack()[depth].function + return "no_method_name" + + +Token = Any +"""Represent an opaque fencing token used by the rendezvous backend.""" + + +class RendezvousBackend(ABC): + """Represent a backend that holds the rendezvous state.""" + + @property + @abstractmethod + def name(self) -> str: + """Get the name of the backend.""" + + @abstractmethod + def get_state(self) -> tuple[bytes, Token] | None: + """Get the rendezvous state. + + Returns: + A tuple of the encoded rendezvous state and its fencing token or + ``None`` if no state is found in the backend. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + @abstractmethod + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """Set the rendezvous state. + + The new rendezvous state is set conditionally: + + - If the specified ``token`` matches the fencing token stored in the + backend, the state will be updated. The new state will be returned + to the caller along with its fencing token. + - If the specified ``token`` does not match the fencing token stored + in the backend, the state won't be updated; instead the existing + state along with its fencing token will be returned to the caller. + - If the specified ``token`` is ``None``, the new state will be set + only if there is no existing state in the backend. Either the new + state or the existing state along with its fencing token will be + returned to the caller. + + Args: + state: + The encoded rendezvous state. + token: + An optional fencing token that was retrieved by a previous call + to :py:meth:`get_state` or ``set_state()``. + + Returns: + A tuple of the serialized rendezvous state, its fencing token, and + a boolean value indicating whether our set attempt succeeded. + + Raises: + RendezvousConnectionError: + The connection to the backend has failed. + RendezvousStateError: + The rendezvous state is corrupt. + """ + + +class RendezvousTimeout: + """Hold the timeout configuration of a rendezvous. + + Args: + join: + The time within which the rendezvous is expected to complete. + last_call: + An additional wait amount before completing the rendezvous once the + rendezvous has the minimum number of required participants. + close: + The time within which the rendezvous is expected to close after a + call to :py:meth:`RendezvousHandler.set_closed` or + :py:meth:`RendezvousHandler.shutdown`. + heartbeat: + The time within which a keep-alive heartbeat is expected to + complete. + """ + + _ZERO = timedelta(0) + + _DEFAULT_TIMEOUTS = { + "join": timedelta(seconds=600), + "last_call": timedelta(seconds=30), + "close": timedelta(seconds=30), + "heartbeat": timedelta(seconds=5), + } + + _join: timedelta + _last_call: timedelta + _close: timedelta + _heartbeat: timedelta + + def __init__( + self, + join: timedelta | None = None, + last_call: timedelta | None = None, + close: timedelta | None = None, + heartbeat: timedelta | None = None, + ) -> None: + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) + + @property + def join(self) -> timedelta: + """Get the join timeout.""" + return self._join + + @property + def last_call(self) -> timedelta: + """Get the last call timeout.""" + return self._last_call + + @property + def close(self) -> timedelta: + """Get the close timeout.""" + return self._close + + @property + def heartbeat(self) -> timedelta: + """Get the keep-alive heartbeat timeout.""" + return self._heartbeat + + def _set_timeouts(self, **timeouts: timedelta | None): + for name, timeout in timeouts.items(): + if timeout is None: + timeout = self._DEFAULT_TIMEOUTS[name] + if timeout <= self._ZERO: + raise ValueError(f"The {name} timeout ({timeout}) must be positive.") + setattr(self, "_" + name, timeout) + + +@dataclass(repr=False, eq=False, frozen=True) +class RendezvousSettings: + """Hold the settings of the rendezvous. + + Attributes: + run_id: + The run id of the rendezvous. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + + run_id: str + min_nodes: int + max_nodes: int + timeout: RendezvousTimeout + keep_alive_interval: timedelta + keep_alive_max_attempt: int + + +@dataclass(eq=True, order=True, frozen=True) +class _NodeDesc: + """Describe a node in the rendezvous. + + Attributes: + addr: + The FQDN of the node or user specified local node address. + pid: + The id of the process in which the rendezvous handler runs. + local_id: + A process-wide unique id. + """ + + addr: str + pid: int + local_id: int + + def __repr__(self) -> str: + return f"{self.addr}_{self.pid}_{self.local_id}" + + +class _NodeDescGenerator: + """Generate node descriptors. + + A node descriptor is a combination of an FQDN, a process id, and an auto- + incremented integer that uniquely identifies a node in the rendezvous. + """ + + _lock: threading.Lock + _local_id: int + + def __init__(self) -> None: + self._lock = threading.Lock() + + # An integer that is incremented with each call to generate(). + self._local_id = 0 + + def generate(self, local_addr: str | None = None) -> _NodeDesc: + # This method can be called by multiple threads concurrently; therefore, + # we must increment the integer atomically. + with self._lock: + local_id = self._local_id + + self._local_id += 1 + + return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) + + +class _RendezvousState: + """Hold the state of a rendezvous. + + Attributes: + round: + The current round of the rendezvous. + complete: + A boolean value indicating whether the current round of the + rendezvous is complete. + deadline: + The time at which the current round of the rendezvous will be + considered complete if it is still waiting for nodes to join. + closed: + A boolean value indicating whether the rendezvous is closed. + participants: + A dictionary of the participants and their corresponding ranks. + wait_list: + A set of nodes that are waiting to participate in the next round of + the rendezvous. + redundancy_list: + A set of nodes that are redundant in the current round and can join + the next rendezvous without triggering re-rendezvous. + last_heartbeats: + A dictionary containing each node's last heartbeat time. + """ + + round: int + complete: bool + deadline: datetime | None + closed: bool + participants: dict[_NodeDesc, int] + wait_list: set[_NodeDesc] + redundancy_list: set[_NodeDesc] + last_heartbeats: dict[_NodeDesc, datetime] + + def __init__(self) -> None: + self.round = 0 + self.complete = False + self.deadline = None + self.closed = False + self.participants = {} + self.wait_list = set() + self.redundancy_list = set() + self.last_heartbeats = {} + + +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: + if state.complete: + # If we do not have any participants left, move to the next round. + if not state.participants: + msg = "No participants left in the rendezvous, marking rendezvous as incomplete" + logger.debug(msg) + state.complete = False + + state.round += 1 + else: + if len(state.participants) < settings.min_nodes: + msg = ( + f"Number of participants {len(state.participants)}) less than" + f"min_nodes {settings.min_nodes}, clearning deadline in state" + ) + logger.debug(msg) + state.deadline = None + + +class _RendezvousStateHolder(ABC): + """Hold the shared rendezvous state synced with other nodes.""" + + @property + @abstractmethod + def state(self) -> _RendezvousState: + """Get the local state.""" + + @abstractmethod + def sync(self) -> bool | None: + """Read or writes the latest state. + + Returns: + A boolean value indicating whether the local state, in case marked + as dirty, was successfully synced with other nodes. + """ + + @abstractmethod + def mark_dirty(self) -> None: + """Mark the local state as dirty.""" + + +class _BackendRendezvousStateHolder(_RendezvousStateHolder): + """Hold the rendezvous state synced with other nodes via a backend. + + Args: + backend: + The rendezvous backend to use. + settings: + The rendezvous settings. + cache_duration: + The amount of time, in seconds, to cache the last rendezvous state + before requesting it from the backend again. + """ + + _backend: RendezvousBackend + _state: _RendezvousState + _settings: RendezvousSettings + _cache_duration: int + _token: Token + _dirty: bool + _last_sync_time: float + _dead_nodes: list[_NodeDesc] + + def __init__( + self, + backend: RendezvousBackend, + settings: RendezvousSettings, + cache_duration: int = 1, + ) -> None: + self._backend = backend + self._state = _RendezvousState() + self._settings = settings + self._cache_duration = cache_duration + self._token = None + self._dirty = False + self._last_sync_time = -1 + self._dead_nodes = [] + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + ) + + @property + def state(self) -> _RendezvousState: + """See base class.""" + return self._state + + def sync(self) -> bool | None: + """See base class.""" + state_bits: bytes | None = None + + token = None + + has_set: bool | None + + if self._dirty: + has_set = False + + state_bits = pickle.dumps(self._state) + + set_response = self._backend.set_state(state_bits, self._token) + if set_response is not None: + state_bits, token, has_set = set_response + else: + has_set = None + + if self._cache_duration > 0: + # Avoid overloading the backend if we are asked to retrieve the + # state repeatedly. Try to serve the cached state. + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): + return None + + get_response = self._backend.get_state() + if get_response is not None: + state_bits, token = get_response + + if state_bits is not None: + try: + self._state = pickle.loads(state_bits) + except pickle.PickleError as exc: + raise RendezvousStateError( + "The rendezvous state is corrupt. See inner exception for details." + ) from exc + else: + self._state = _RendezvousState() + + if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): + node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) + + msg = ( + f"As part of the sync operation the node(s) {node_list} have been removed from the " + f"rendezvous '{self._settings.run_id}' since they had no heartbeat." + ) + self._record(message=msg) + logger.debug(msg) + + self._token = token + + self._dirty = False + + self._last_sync_time = time.monotonic() + + self._sanitize() + + return has_set + + def _sanitize(self) -> None: + state = self._state + + expire_time = datetime.now(timezone.utc) - ( + self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt + ) + + # Filter out the dead nodes. + self._dead_nodes = [ + node + for node, last_heartbeat in state.last_heartbeats.items() + if last_heartbeat < expire_time + ] + + participant_removed = False + + for dead_node in self._dead_nodes: + msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" + logger.debug(msg) + del state.last_heartbeats[dead_node] + + try: + del state.participants[dead_node] + + participant_removed = True + except KeyError: + pass + + try: + state.wait_list.remove(dead_node) + except KeyError: + pass + + try: + state.redundancy_list.remove(dead_node) + except KeyError: + pass + + if participant_removed: + # Common epilogue shared with the _remove_from_participants() + # function of _DistributedRendezvousOpExecutor. + _remove_participant_epilogue(state, self._settings) + + def mark_dirty(self) -> None: + """See base class. + + If the local rendezvous state is dirty, the next sync call will try to + write the changes back to the backend. However this attempt might fail + if another node, which had the same state, also made changes and wrote + them before us. + """ + self._dirty = True + + +class _Action(Enum): + """Specifies the possible actions based on the state of the rendezvous.""" + + KEEP_ALIVE = 1 + ADD_TO_PARTICIPANTS = 2 + ADD_TO_WAIT_LIST = 3 + ADD_TO_REDUNDANCY_LIST = 4 + REMOVE_FROM_PARTICIPANTS = 5 + REMOVE_FROM_WAIT_LIST = 6 + REMOVE_FROM_REDUNDANCY_LIST = 7 + MARK_RENDEZVOUS_COMPLETE = 8 + MARK_RENDEZVOUS_CLOSED = 9 + SYNC = 10 + ERROR_CLOSED = 11 + ERROR_TIMEOUT = 12 + FINISH = 13 + + +class _RendezvousContext: + """Holds the context of the rendezvous. + + Attributes: + node: + The node descriptor associated with the current rendezvous handler + instance. + state: + The current state of the rendezvous. + settings: + The rendezvous settings. + """ + + node: _NodeDesc + state: _RendezvousState + settings: RendezvousSettings + + def __init__( + self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings + ) -> None: + self.node = node + self.state = state + self.settings = settings + + +class _RendezvousOpExecutor(ABC): + """Execute rendezvous operations.""" + + @abstractmethod + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Callable[[timedelta], float] | None = None, + ) -> None: + """Execute a rendezvous operation. + + An operation is run inside a state machine and is expected to transition + the rendezvous from one state to another. + + Args: + state_handler: + A callable that is expected to return the next state transition + action based on the current state of the rendezvous. + deadline: + The time, in seconds, at which the operation will be considered + timed-out. + update_deadline: + Function to generate a new operation deadline if the current + node may participate in the next rendezvous. + """ + + +class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): + """Execute rendezvous operations using a shared state. + + Args: + node: + The node descriptor associated with the current rendezvous handler + instance. + state_holder: + The ``RendezvousStateHolder`` to use to sync the rendezvous state + with other nodes. + settings: + The rendezvous settings. + """ + + _node: _NodeDesc + _state: _RendezvousState + _state_holder: _RendezvousStateHolder + _settings: RendezvousSettings + + def __init__( + self, + node: _NodeDesc, + state_holder: _RendezvousStateHolder, + settings: RendezvousSettings, + ) -> None: + self._node = node + self._state_holder = state_holder + self._settings = settings + + def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._node.addr, + pid=self._node.pid, + local_id=self._node.local_id, + ) + + def run( + self, + state_handler: Callable[[_RendezvousContext, float], _Action], + deadline: float, + update_deadline: Callable[[timedelta], float] | None = None, + ) -> None: + """See base class.""" + action = None + while action != _Action.FINISH: + # Reads or writes the latest rendezvous state shared by all nodes in + # the rendezvous. Note that our local changes might get overridden + # by another node if that node synced its changes before us. + has_set = self._state_holder.sync() + if has_set is not None: + if has_set: + msg = ( + f"The node '{self._node}' has successfully synced its local changes with " + f"other nodes in the rendezvous '{self._settings.run_id}'." + ) + else: + msg = ( + f"The node '{self._node}' has a stale state and failed to sync its local " + f"changes with other nodes in the rendezvous '{self._settings.run_id}'." + ) + + self._record(message=msg) + logger.debug(msg) + + self._state = self._state_holder.state + + ctx = _RendezvousContext(self._node, self._state, self._settings) + + # Determine the next action to take based on the current state of + # the rendezvous. + action = state_handler(ctx, deadline) + + if action == _Action.FINISH: + continue + + if action == _Action.ERROR_CLOSED: + raise RendezvousClosedError + + if action == _Action.ERROR_TIMEOUT: + raise RendezvousTimeoutError + + if action == _Action.SYNC: + # Delay the execution by one second to avoid overloading the + # backend if we are asked to poll for state changes. + _delay(seconds=1) + else: + if action == _Action.KEEP_ALIVE: + self._keep_alive() + elif action == _Action.ADD_TO_PARTICIPANTS: + self._add_to_participants() + elif action == _Action.ADD_TO_WAIT_LIST: + self._add_to_wait_list() + elif action == _Action.ADD_TO_REDUNDANCY_LIST: + self._add_to_redundancy_list() + elif action == _Action.REMOVE_FROM_PARTICIPANTS: + self._remove_from_participants() + elif action == _Action.REMOVE_FROM_WAIT_LIST: + self._remove_from_wait_list() + elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: + self._remove_from_redundancy_list() + # update deadline since the node may participate in rendezvous process + if update_deadline: + deadline = update_deadline(self._settings.timeout.join) + elif action == _Action.MARK_RENDEZVOUS_COMPLETE: + self._mark_rendezvous_complete() + elif action == _Action.MARK_RENDEZVOUS_CLOSED: + self._mark_rendezvous_closed() + + # Attempt to sync our changes back to other nodes. + self._state_holder.mark_dirty() + + def _keep_alive(self) -> None: + msg = ( + f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " + f"'{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.last_heartbeats[self._node] = datetime.now(timezone.utc) + + def _add_to_participants(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + try: + state.wait_list.remove(self._node) + except KeyError: + pass + + # The ranks of the participants will be set once the rendezvous is + # complete. + state.participants[self._node] = 0 + + self._keep_alive() + + if len(state.participants) == self._settings.min_nodes: + state.deadline = ( + datetime.now(timezone.utc) + self._settings.timeout.last_call + ) + + if len(state.participants) == self._settings.max_nodes: + self._mark_rendezvous_complete() + + def _add_to_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + if self._node in self._state.redundancy_list: + self._state.redundancy_list.remove(self._node) + self._state.wait_list.add(self._node) + + self._keep_alive() + + def _add_to_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' added itself to the redundancy list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.add(self._node) + + self._keep_alive() + + def _remove_from_participants(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the participants of round " + f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + state = self._state + + del state.participants[self._node] + + del state.last_heartbeats[self._node] + + # Common epilogue shared with the sanitizer() function of + # _BackendRendezvousStateHolder. + _remove_participant_epilogue(state, self._settings) + + def _remove_from_wait_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the wait list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.wait_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _remove_from_redundancy_list(self) -> None: + msg = ( + f"The node '{self._node}' removed itself from the redundant list of round " + f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." + ) + self._record(message=msg) + logger.debug(msg) + + self._state.redundancy_list.remove(self._node) + + del self._state.last_heartbeats[self._node] + + def _mark_rendezvous_complete(self) -> None: + msg = ( + f"The node '{self._node}' marked round {self._state.round} of the rendezvous " + f"'{self._settings.run_id}' as complete. Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + state = self._state + + state.complete = True + state.deadline = None + + # Assign the ranks. + for rank, node in enumerate(sorted(state.participants)): + state.participants[node] = rank + + def _mark_rendezvous_closed(self) -> None: + msg = ( + f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " + "Pending sync." + ) + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.debug(msg) + + self._state.closed = True + + +def _should_keep_alive(ctx: _RendezvousContext) -> bool: + """Determine whether a keep-alive heartbeat should be sent.""" + try: + last_heartbeat = ctx.state.last_heartbeats[ctx.node] + except KeyError: + return False + + return ( + last_heartbeat <= datetime.now(timezone.utc) - ctx.settings.keep_alive_interval + ) + + +class _RendezvousExitOp: + """Represent a rendezvous exit operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.node in ctx.state.participants: + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.REMOVE_FROM_PARTICIPANTS + return _Action.FINISH + + +class _RendezvousJoinOp: + """Represent a rendezvous join operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + state = ctx.state + + # A closed rendezvous means that it no longer accepts new nodes. + if state.closed: + if ctx.node in state.redundancy_list: + msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." + raise RendezvousGracefulExitError(msg) + return _Action.ERROR_CLOSED + + if ctx.node in state.redundancy_list: + msg = f"The node {ctx.node} is in redundancy list" + logger.debug(msg) + # don't apply the timeout logic here, since we want to allow the node to rejoin + if len(state.participants) == ctx.settings.max_nodes: + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + else: + return _Action.SYNC + else: + # transition to waiting state that will respect timeouts. + msg = f"The node {ctx.node} is removed from redundancy list" + logger.debug(msg) + return _Action.REMOVE_FROM_REDUNDANCY_LIST + + is_participant = ctx.node in state.participants + + # If we are part of the rendezvous and it is already complete there is + # no further action to take. + if state.complete and is_participant: + return _Action.FINISH + + now = time.monotonic() + if now > deadline: + rollback_period = 5 # 5 seconds + + # If we still have time to rollback (a short period on top of the + # operation deadline), try to remove ourself from the rendezvous. + # It is okay if we can't though as our keep-alive will eventually + # expire. + if now <= deadline + rollback_period: + # If we are part of the rendezvous, it means we couldn't find + # enough participants to complete it on time. + if is_participant: + return _Action.REMOVE_FROM_PARTICIPANTS + # If we are in the wait list, it means we couldn't wait till the + # next round of the rendezvous. + if ctx.node in state.wait_list: + return _Action.REMOVE_FROM_WAIT_LIST + return _Action.ERROR_TIMEOUT + + if state.complete: + # If we are here, it means we are not part of the rendezvous. In + # case the rendezvous has capacity for additional participants add + # ourself to the wait list for the next round. + if len(state.participants) < ctx.settings.max_nodes: + if ctx.node not in state.wait_list: + return _Action.ADD_TO_WAIT_LIST + elif len(state.participants) >= ctx.settings.max_nodes: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): + return _Action.ADD_TO_REDUNDANCY_LIST + elif is_participant: + # If the rendezvous has enough number of participants including us, + # check whether we have passed the rendezvous deadline. If yes, + # complete it. + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + and state.deadline is not None + ): + if state.deadline < datetime.now(timezone.utc): + msg = ( + f"The node '{ctx.node}' marking the rendezvous complete, " + f"quorum established within deadline" + ) + logger.debug(msg) + return _Action.MARK_RENDEZVOUS_COMPLETE + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" + logger.debug(msg) + else: + msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" + logger.debug(msg) + else: + # The rendezvous is not complete yet and we are not part of it. Try + # to join. + return _Action.ADD_TO_PARTICIPANTS + + if _should_keep_alive(ctx): + return _Action.KEEP_ALIVE + + # At this point either the rendezvous is not complete, but we are part + # of it, which means we have to wait for other participants to join; or + # the rendezvous is complete, but we are not part of it, which means we + # have to wait for the next round. + return _Action.SYNC + + +class _RendezvousCloseOp: + """Represent a rendezvous close operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if ctx.state.closed: + return _Action.FINISH + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.MARK_RENDEZVOUS_CLOSED + + +class _RendezvousKeepAliveOp: + """Represent a rendezvous keep-alive update operation.""" + + def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: + if _should_keep_alive(ctx): + if time.monotonic() > deadline: + return _Action.ERROR_TIMEOUT + return _Action.KEEP_ALIVE + return _Action.FINISH + + +class DynamicRendezvousHandler(RendezvousHandler): + """Represent a handler that sets up a rendezvous among a set of nodes.""" + + # Static + _node_desc_generator = _NodeDescGenerator() + + _this_node: _NodeDesc + _settings: RendezvousSettings + _backend_name: str + _store: Store + _state_holder: _RendezvousStateHolder + _op_executor: _RendezvousOpExecutor + _heartbeat_lock: threading.Lock + _keep_alive_timer: _PeriodicTimer | None + + @classmethod + def from_backend( + cls, + run_id: str, + store: Store, + backend: RendezvousBackend, + min_nodes: int, + max_nodes: int, + local_addr: str | None = None, + timeout: RendezvousTimeout | None = None, + keep_alive_interval: int = 5, + keep_alive_max_attempt: int = 3, + ): + """Create a new :py:class:`DynamicRendezvousHandler`. + + Args: + run_id: + The run id of the rendezvous. + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + min_nodes: + The minimum number of nodes to admit to the rendezvous. + max_nodes: + The maximum number of nodes to admit to the rendezvous. + local_addr: + The local node address. + timeout: + The timeout configuration of the rendezvous. + keep_alive_interval: + The amount of time a node waits before sending a heartbeat to keep + it alive in the rendezvous. + keep_alive_max_attempt: + The maximum number of failed heartbeat attempts after which a node + is considered dead. + """ + # We associate each handler instance with a unique node descriptor. + node = cls._node_desc_generator.generate(local_addr) + + settings = RendezvousSettings( + run_id, + min_nodes, + max_nodes, + timeout or RendezvousTimeout(), + keep_alive_interval=timedelta(seconds=keep_alive_interval), + keep_alive_max_attempt=keep_alive_max_attempt, + ) + + state_holder = _BackendRendezvousStateHolder(backend, settings) + + return cls(node, settings, backend.name, store, state_holder) + + def __init__( + self, + node: _NodeDesc, + settings: RendezvousSettings, + backend_name: str, + store: Store, + state_holder: _RendezvousStateHolder, + ) -> None: + if not settings.run_id: + raise ValueError("The run id must be a non-empty string.") + + if settings.min_nodes < 1: + raise ValueError( + f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." + ) + + if settings.max_nodes < settings.min_nodes: + raise ValueError( + f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " + f"to the minimum number of nodes ({settings.min_nodes})." + ) + + self._this_node = node + + self._settings = settings + + self._backend_name = backend_name + + self._store = store + + self._state_holder = state_holder + + self._op_executor = _DistributedRendezvousOpExecutor( + self._this_node, self._state_holder, self._settings + ) + + self._heartbeat_lock = threading.Lock() + + self._keep_alive_timer = None + + # Cached shared store server reference + self._shared_tcp_store_server: dist.Store | None = None + + self._bootstrap_store_info: RendezvousStoreInfo | None = None + + def _record( + self, + message: str, + node_state: NodeState = NodeState.RUNNING, + rank: int | None = None, + ) -> None: + construct_and_record_rdzv_event( + name=f"{self.__class__.__name__}.{get_method_name()}", + run_id=self._settings.run_id, + message=message, + node_state=node_state, + hostname=self._this_node.addr, + pid=self._this_node.pid, + local_id=self._this_node.local_id, + rank=rank, + ) + + def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore: + return dist.TCPStore( + host_name=master_addr, + port=master_port, + is_master=True, + multi_tenant=True, + ) + + @property + def settings(self) -> RendezvousSettings: + """Get the settings of the rendezvous.""" + return self._settings + + def get_backend(self) -> str: + """See base class.""" + return self._backend_name + + @property + def use_agent_store(self) -> bool: + """See base class.""" + return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" + + def next_rendezvous(self) -> RendezvousInfo: + """See base class.""" + msg = ( + f"The node '{self._this_node}' attempts to join the next round of the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.info(msg) + + try: + self._stop_heartbeats() + + # Delay the execution for a small random amount of time if this is our + # first run. This will slightly skew the rendezvous attempts across the + # nodes and reduce the load on the backend. + if self._state_holder.state.round == 0: + _delay(seconds=(0, 0.3)) + + exit_op = _RendezvousExitOp() + join_op = _RendezvousJoinOp() + + deadline = self._get_deadline(self._settings.timeout.join) + self._op_executor.run(exit_op, deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) + + self._start_heartbeats() + + rank, world_size = self._get_world() + store = self._get_store() + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + msg = ( + f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " + f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " + f"{world_size}." + ) + self._record(message=msg, rank=rank) + logger.info(msg) + + # opt-out option of TCPStore sharing + if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) + return RendezvousInfo( + store, + rank, + world_size, + bootstrap_store_info, + ) + + # This will only be hit when TCPStore sharing is enabled. + if self._bootstrap_store_info is None: + # To avoid race in get_free_port because we release the port after the call, + # we want to create a TCPStore server soon afterwards. + server_port = 0 + if rank == 0: + self._shared_tcp_store_server = self._create_tcp_store_server( + self._this_node.addr, server_port + ) + server_port = self._shared_tcp_store_server.port + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, + store, + local_addr=self._this_node.addr, + server_port=server_port, # For non-0 rank, this is a no-op + ) + + assert self._bootstrap_store_info is not None + if rank == 0: + assert self._shared_tcp_store_server is not None + + return RendezvousInfo( + store, + rank, + world_size, + self._bootstrap_store_info, # type: ignore[assignment] + ) + + def is_closed(self) -> bool: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return self._state_holder.state.closed + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def set_closed(self) -> None: + """See base class.""" + try: + with self._heartbeat_lock: + self._close() + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def num_nodes_waiting(self) -> int: + """See base class.""" + try: + with self._heartbeat_lock: + self._state_holder.sync() + + return len(self._state_holder.state.wait_list) + + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def get_run_id(self) -> str: + """See base class.""" + return self._settings.run_id + + def shutdown(self) -> bool: + """See base class.""" + self._stop_heartbeats() + + try: + self._close() + + return True + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to shutdown the rendezvous " + f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + return False + except Exception as e: + self._record( + message=f"{type(e).__name__}: {str(e)}", + node_state=NodeState.FAILED, + ) + raise + + def _close(self) -> None: + op = _RendezvousCloseOp() + + deadline = self._get_deadline(self._settings.timeout.close) + + self._op_executor.run(op, deadline) + + msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." + self._record(message=msg, node_state=NodeState.SUCCEEDED) + logger.info(msg) + + @staticmethod + def _keep_alive_weak(weak_self) -> None: + self = weak_self() + if self is not None: + self._keep_alive() + + def _keep_alive(self) -> None: + with self._heartbeat_lock: + op = _RendezvousKeepAliveOp() + + deadline = self._get_deadline(self._settings.timeout.heartbeat) + + try: + self._op_executor.run(op, deadline) + + msg = ( + f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " + f"'{self._settings.run_id}'." + ) + self._record(message=msg) + logger.debug(msg) + except RendezvousError as ex: + msg = ( + f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " + f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." + ) + self._record(message=msg, node_state=NodeState.FAILED) + logger.warning(msg) + + def _start_heartbeats(self) -> None: + self._keep_alive_timer = _PeriodicTimer( + self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) + ) + + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) + + self._keep_alive_timer.start() + + def _stop_heartbeats(self) -> None: + if self._keep_alive_timer is None: + return + + self._keep_alive_timer.cancel() + + def _get_world(self) -> tuple[int, int]: + state = self._state_holder.state + + return state.participants[self._this_node], len(state.participants) + + def _wrap_store(self, store: Store) -> Store: + key_prefix = ( + f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) + + return dist.PrefixStore(key_prefix, store) + + def _get_store(self) -> Store: + return self._wrap_store(self._store) + + def _get_deadline(self, timeout: timedelta) -> float: + return time.monotonic() + timeout.total_seconds() + + +def _get_timeout(params: RendezvousParameters, key: str) -> timedelta | None: + timeout = params.get_as_int(key + "_timeout") + if timeout is None: + return None + return timedelta(seconds=timeout) + + +def create_handler( + store: Store, backend: RendezvousBackend, params: RendezvousParameters +) -> DynamicRendezvousHandler: + """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. + + Args: + store: + The C10d store to return as part of the rendezvous. + backend: + The backend to use to hold the rendezvous state. + + +-------------------+------------------------------------------------------+ + | Parameter | Description | + +===================+======================================================+ + | join_timeout | The total time, in seconds, within which the | + | | rendezvous is expected to complete. Defaults to 600 | + | | seconds. | + +-------------------+------------------------------------------------------+ + | last_call_timeout | An additional wait amount, in seconds, before | + | | completing the rendezvous once the minimum number of | + | | nodes has been reached. Defaults to 30 seconds. | + +-------------------+------------------------------------------------------+ + | close_timeout | The time, in seconds, within which the rendezvous is | + | | expected to close after a call to | + | | :py:meth:`RendezvousHandler.set_closed` or | + | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | + | | 30 seconds. | + +-------------------+------------------------------------------------------+ + | heartbeat | The time, in seconds, within which a keep-alive | + | | heartbeat is expected to complete | + +-------------------+------------------------------------------------------+ + """ + try: + timeout = RendezvousTimeout( + _get_timeout(params, "join"), + _get_timeout(params, "last_call"), + _get_timeout(params, "close"), + _get_timeout(params, "heartbeat"), + ) + keep_alive_interval = params.get_as_int("keep_alive_interval", 5) + if keep_alive_interval is None: + raise TypeError( + "You passed 'keep_alive_interval=None' as a rendezvous configuration option" + ) + keep_alive_max_attempt = params.get_as_int("keep_alive_max_attempt", 3) + if keep_alive_max_attempt is None: + raise TypeError( + "You passed 'keep_alive_max_attempt=None' as a rendezvous configuration option" + ) + + return DynamicRendezvousHandler.from_backend( + params.run_id, + store, + backend, + params.min_nodes, + params.max_nodes, + params.local_addr, + timeout, + keep_alive_interval=keep_alive_interval, + keep_alive_max_attempt=keep_alive_max_attempt, + ) + except Exception as e: + construct_and_record_rdzv_event( + message=f"{type(e).__name__}: {str(e)}", + run_id=params.run_id, + node_state=NodeState.FAILED, + ) + raise diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..93a7073bed87a33a7f2ba0dfb64c7daa57b9d55f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -0,0 +1,1080 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import sys +import threading +import time + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + +from torch.distributed.elastic.rendezvous import ( + RendezvousClosedError, + RendezvousError, + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, + RendezvousTimeoutError, +) + +from .etcd_store import cas_delay, EtcdStore +from .utils import parse_rendezvous_endpoint + + +__all__ = [ + "EtcdRendezvousRetryableFailure", + "EtcdRendezvousRetryImmediately", + "EtcdRendezvousHandler", + "EtcdRendezvous", + "create_rdzv_handler", +] + +_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") +_log_handler = logging.StreamHandler(sys.stderr) +_log_handler.setFormatter(_log_fmt) + +logger = logging.getLogger(__name__) +logger.propagate = False +logger.setLevel(logging.INFO) +logger.addHandler(_log_handler) + + +# Retryable failure exception means the we were too late to make +# a desired state transition (e.g. because of a race condition), +# and should now restart from the beginning. +# A small delay is recommended to avoid spamming Etcd. +class EtcdRendezvousRetryableFailure(Exception): + pass + + +# Similar to retryable failure, but the new state we observed suggests we +# can re-try immediately, i.e. without a need for "safety delay". +class EtcdRendezvousRetryImmediately(Exception): + pass + + +# Default timeout for the rendezvous. +_DEFAULT_TIMEOUT: int = 600 # 10 minutes + +# Additional waiting time after reaching the minimum number of nodes +# in case the rendezvous is elastic (min != max). +_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds + +# Various constants used internally in EtcdRendezvous +CONST_ETCD_SETUP_TTL = 5 +CONST_ETCD_FROZEN_TTL = 10 +CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 + +# Ephemeral node TTL for worker's keep-alive key: +CONST_WORKER_KEEPALIVE_TTL = 10 + +# TTL for the ephemeral run_id-specific directory. All rendezvous state data +# for a specific run_id (job instance) is contained within directory. +# Its only role is to clean-up rendezvous data from old runs (for the case when +# etcd server is persistent), and has no affect on correctness, but should be +# larger than any timeouts that a worker process is expected to survive: +CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours + + +class EtcdRendezvousHandler(RendezvousHandler): + """ + Implements a + :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface + backed by + :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. + ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to + use and to pass implementation specific configurations to the rendezvous + module. The basic etcd rendezvous configuration URL looks like the following + :: + + etcd://:/?min_workers=&max_workers= # noqa: W605 + + -- example -- + + etcd://localhost:2379/1234?min_workers=1&max_workers=3 + + The URL above is interpreted as follows: + + 1. Use the rendezvous handler that is registered with the ``etcd`` + scheme + 2. The ``etcd`` endpoint to use is ``localhost:2379`` + 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to + share a common etcd server for multiple jobs so long as the + ``job_ids`` are guaranteed to be unique). Note that the job id can be + any string (e.g. does not need to be a number) as long as it is + unique. + 4. ``min_workers=1`` and ``max_workers=3`` specifies a range for + membership size - Torch Distributed Elastic starts running the job as + long as the cluster size is greater than or equal to ``min_workers`` + and admits up to ``max_workers`` into the cluster. + + Below are a full list of the parameters that can be passed to etcd + rendezvous: + + +--------------------------------------------+--------------------------+ + | Parameter | Description | + +============================================+==========================+ + | min_workers | minimum number of | + | | workers for the | + | | rendezvous to be valid | + +--------------------------------------------+--------------------------+ + | max_workers | maximum number of | + | | workers to admit | + +--------------------------------------------+--------------------------+ + | timeout | total timeout within | + | | which next_rendezvous is | + | | expected to succeed | + | | (default 600s) | + +--------------------------------------------+--------------------------+ + | last_call_timeout | additional wait amount | + | | ("last call") after min | + | | number of workers has | + | | been reached (defaults | + | | to 30s) | + +--------------------------------------------+--------------------------+ + | etcd_prefix | path prefix (from etcd | + | | root), inside which all | + | | etcd nodes will be | + | | created (defaults to | + | | ``/torchelastic/p2p``) | + +--------------------------------------------+--------------------------+ + """ + + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: str | None): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr + + def __del__(self): + # TODO: look into using weakref here instead. + del self._rdzv_impl + + def get_backend(self) -> str: + return "etcd" + + def next_rendezvous(self): + rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() + + logger.info("Creating EtcdStore as the c10d::Store implementation") + store = self._rdzv_impl.setup_kv_store(rdzv_version) + + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) + return RendezvousInfo(store, rank, world_size, bootstrap_store_info) + + def is_closed(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + return state["status"] == "closed" + except etcd.EtcdKeyNotFound: + # No rendezvous state, so it cannot be closed. + return False + + def set_closed(self): + self._rdzv_impl.set_closed() + + def num_nodes_waiting(self): + try: + _, state = self._rdzv_impl.get_rdzv_state() + if state["status"] == "final": + return state["num_workers_waiting"] + except etcd.EtcdKeyNotFound: + pass + return 0 + + def get_run_id(self) -> str: + return self._rdzv_impl._run_id + + def shutdown(self) -> bool: + try: + self.set_closed() + return True + except BaseException: # noqa: B036 + logger.warning("Shutdown failed", exc_info=True) + return False + + +# TODO: we should probably handle a few additional errors, +# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are +# only relevant for multi-node Etcd ensemble. A simple retry would work, +# but is verbose to add everywhere. Consider wrapping the client calls +# into auto-retry for these errors? +# +class EtcdRendezvous: + """A rendezvous implementation that uses `etcd `__ as the backend store.""" + + def __init__( + self, + client, + prefix, + run_id, + num_min_workers, + num_max_workers, + timeout, + last_call_timeout, + ): + self.client = client + logger.info("Etcd machines: %s", self.client.machines) + + self._prefix = prefix + self._run_id = run_id + self._num_min_workers = num_min_workers + self._num_max_workers = num_max_workers + self._timeout = timeout + self._last_call_timeout = last_call_timeout + + # For cleaning up TTL refresher threads (for ephemeral keys) + self._lease_run_id_stop = None + self._lease_this_rank_stop = None + + if not self._prefix.endswith("/"): + self._prefix += "/" + + # Setup a permanent prefix dir, if didn't exist + if self._prefix != "/": + self.create_path_if_not_exists(self._prefix) + + # Lease a "sub-root" node specific to this job instance (run_id) + self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) + self._lease_run_id_stop = self.setup_lease_renewal( + self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL + ) + + # Subdir for all rendezvous work + self.create_path_if_not_exists(self.get_path("/rdzv")) + + # Create a rendezvous version counter, if doesn't exist + try: + self.client.write( + key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False + ) + except etcd.EtcdAlreadyExist: + pass + + def __del__(self): + # TODO: look into using weakref here instead. + if self._lease_run_id_stop is not None: + self._lease_run_id_stop.set() + + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + def rendezvous_barrier(self): + """ + Main entry point for next rendezvous. + + This method is blocking until rendezvous succeeds or a timeout occurs. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousTimeoutError - timeout waiting for rendezvous + RendezvousClosedError - rendezvous is or was closed while waiting + RendezvousError - other persistent errors that + render the rendezvous non-retryable + """ + self._rendezvous_deadline = time.time() + self._timeout + while True: + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + logger.info("Attempting to join next rendezvous") + try: + # Dis-own our lease in the previous rendezvous, if exists + if self._lease_this_rank_stop is not None: + self._lease_this_rank_stop.set() + + return self.init_phase() + + except EtcdRendezvousRetryImmediately: + # The type of failure suggests we can retry without delay + pass + + except EtcdRendezvousRetryableFailure: + # In case of retryable failure, wait a small delay + # to avoid spamming etcd + time.sleep(1) + + except RendezvousTimeoutError: + logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") + raise + + except RendezvousClosedError: + logger.info( + "Rendezvous for run_id=%s was observed to be closed", self._run_id + ) + raise + + except RendezvousError: + raise + + except Exception as e: + # In case of a general exception, wait a small delay + # to avoid spamming etcd + # FIXME: there are a few things that fall under this like + # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. + logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) # noqa: G200 + time.sleep(1) + + def init_phase(self): + """ + Initially, the rendezvous state is expected to be one of: + + 1. empty (non-existent) - in this case we try to create a new one. + 2. joinable - we try to join it. + 3. final - we announce ourselves as waiting, and go into monitoring mode + + Any other state is considered transitional, and will be retried after + a short delay. + + Returns: + ``(rdzv_version, rank, world_size)`` + + Raises: + RendezvousClosedError - current rendezvous was/is closed + EtcdRendezvousRetryableFailure - observed some intermediate + state, which is best handled by retrying later + """ + try: + active_version = self.try_create_rendezvous() + state = json.loads(active_version.value) + logger.info("New rendezvous state created: %s", state) + except etcd.EtcdAlreadyExist: + active_version, state = self.get_rdzv_state() + # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), + # but this is ok for us - just means we'll restart from beginning. + logger.info("Observed existing rendezvous state: %s", state) + + if state["status"] == "closed": + raise RendezvousClosedError + + if state["status"] == "joinable": + return self.join_phase(state["version"]) + + if state["status"] == "final": + self.handle_existing_rendezvous(state["version"]) + raise EtcdRendezvousRetryImmediately + + self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) + raise EtcdRendezvousRetryableFailure + + def join_phase(self, expected_version): + """ + We observed a rendezvous state in 'joinable' state, and attempt to join this + particular version, and then wait for all other peers to join. + """ + # Failure to join will propagate an exception, causing a re-entry. + active_version, this_rank = self.join_rendezvous(expected_version) + state = json.loads(active_version.value) + logger.info( + "Joined rendezvous version %s as rank %s. Full state: %s", + state["version"], + this_rank, + state, + ) + + # If this worker was first to reach num_min_workers requirement, + # and rendezvous is still joinable (therefore it is elastic), + # then this worker will be responsible for waiting out the "last call" + # timeout and closing (i.e. transitioning to 'frozen') the rendezvous + # afterwards. + # As a safety against a potential failure of this worker (during the + # last call timeout), the rendezvous state is made ephemeral + # when min_num_workers is reached. + + if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": + logger.info("Rank %s is responsible for join last call.", this_rank) + last_call_deadline = time.time() + self._last_call_timeout + self.handle_join_last_call(expected_version, last_call_deadline) + logger.info("Rank %s finished join last call.", this_rank) + + # Wait for rendezvous state to be frozen, which means a fixed set of peers + logger.info("Waiting for remaining peers.") + active_version = self.wait_for_peers(expected_version) + state = json.loads(active_version.value) + + assert state["version"] == expected_version, ( + "Logic error: failed to observe version mismatch" + ) + + return self.confirm_phase(expected_version, this_rank) + + def confirm_phase(self, expected_version, this_rank): + """ + Once the rendezvous state transitions from 'joinable' to 'frozen', + we have every participant confirm their membership and setup per-member + keep-alive TTL keys, and then wait for all other participants to confirm, + which would then successfully conclude this rendezvous. + """ + logger.info("All peers arrived. Confirming membership.") + self.confirm_membership(expected_version, this_rank) + + logger.info("Waiting for confirmations from all peers.") + active_version = self.wait_for_final(expected_version) + state = json.loads(active_version.value) + + logger.info( + "Rendezvous version %s is complete. Final state: %s", + state["version"], + state, + ) + + # Rendezvous version number; our rank in it; world size + return state["version"], this_rank, len(state["participants"]) + + def handle_existing_rendezvous(self, expected_version): + """ + Handle the case when there's an existing (state 'final) rendezvous already + in place, and we have to announce ourselves waiting, and wait until + the next rendezvous opportunity. + """ + # If state is 'final' -> increment num_workers_waiting + # Then, observe state changes: + # 1. if it's no longer final -> bail out and re-try + # 2. if keep alives are missing, destroy it and bail out. + active_state = self.announce_self_waiting(expected_version) + logger.info( + "Added self to waiting list. Rendezvous full state: %s", active_state.value + ) + + self.wait_for_rendezvous_to_free(expected_version) + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) + + def try_create_rendezvous(self): + """ + Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). + + Raises: + RendezvousError - on unexpected state + """ + # Initially active_version is ephemeral - this is to handle the + # possibility that might fail to complete the setup transaction, + # i.e. the transition "setup" -> "joinable". + active_version = self.client.write( + key=self.get_path("/rdzv/active_version"), + value=json.dumps({"status": "setup"}), + prevExist=False, + ttl=CONST_ETCD_SETUP_TTL, + ) + + try: + version_counter = self.client.get(self.get_path("/rdzv/version_counter")) + version_counter.value = str(int(version_counter.value) + 1) + self.client.update(version_counter) + except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: + raise RendezvousError( + "Unexpected state of EtcdRendezvousHandler, worker needs to die." + ) from e + + # Any failure below results in declaring a retryable rendezvous failure. + # The ephemeral /rdzv/active_version will expire and someone can then + # re-try the setup process. + + # Create directory node for participant data + self.client.write( + key=self.get_path(f"/rdzv/v_{version_counter.value}"), + value=None, + dir=True, + prevExist=False, + ) + + # Publish rendezvous version and signal it is ready-to-be-joined. + # If rendezvous was set closed just before this, a retry will happen, + # where the closed condition will be handled. + return self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps( + { + "status": "joinable", + "version": version_counter.value, + "participants": [], + } + ), + prev_value=active_version.value, + ) + + def join_rendezvous(self, expected_version): + """Helper method for the join phase.""" + # Use compare-and-swap to add self to rendezvous state: + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "joinable": + raise EtcdRendezvousRetryableFailure( + "Rendezvous state became non-joinable before we could join. " + "Must join next one." + ) + + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + assert len(state["participants"]) < self._num_max_workers, ( + "Logic error: joinable rendezvous should always have space left" + ) + + this_rank = len(state["participants"]) + state["participants"].append(this_rank) + + # When reaching min workers, or changing state to frozen, we'll set + # the active_version node to be ephemeral. + set_ttl: int | None = None + if len(state["participants"]) == self._num_max_workers: + state["status"] = "frozen" + state["keep_alives"] = [] + set_ttl = CONST_ETCD_FROZEN_TTL + elif len(state["participants"]) >= self._num_min_workers: + set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL + + try: + # Compare-and-swap. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=set_ttl, + ) + # We succeeded joining. + return active_version, this_rank + + except etcd.EtcdCompareFailed: + logger.info("Join rendezvous CAS unsuccessful, retrying") + + def wait_for_peers(self, expected_version): + """Helper method for the join phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Success, all peers arrived. + return active_version + + elif state["status"] == "joinable" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def confirm_membership(self, expected_version, this_rank): + """Helper method for the confirm phase.""" + # Compare-and-swap loop + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "frozen": + raise EtcdRendezvousRetryImmediately( + "Rendezvous no longer frozen, before we confirmed. " + "Must join next one" + ) + if state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately( + "Rendezvous version changed. Must try join the new one." + ) + + this_lease_key = self.get_path( + f"/rdzv/v_{expected_version}/rank_{this_rank}" + ) + self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) + + state["keep_alives"].append(this_lease_key) + if len(state["keep_alives"]) == len(state["participants"]): + # Everyone confirmed (this rank is last to do so) + state["status"] = "final" + state["num_workers_waiting"] = 0 + finalize = True + else: + finalize = False + + try: + # Compare-and-swap. If new state is still frozen, keep it ephemeral. + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=None if finalize else CONST_ETCD_FROZEN_TTL, + ) + + self._lease_this_rank_stop = self.setup_lease_renewal( + this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Confirm membership CAS unsuccessful, retrying") + + def wait_for_final(self, expected_version): + """Helper method for the confirm phase.""" + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "final" and state["version"] == expected_version: + # Success. This rendezvous is final, and we accept it. + return active_version + + elif state["status"] == "frozen" and state["version"] == expected_version: + # Continue waiting for any interesting events. + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1 + ) + + else: + # No valid transition possible at this point + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + def announce_self_waiting(self, expected_version): + """ + Announce this worker is waiting (via num_workers_waiting counter) to join next + rendezvous, but only if state and version match. + """ + while True: + cas_delay() + active_version, state = self.get_rdzv_state() + + if state["status"] != "final" or state["version"] != expected_version: + raise EtcdRendezvousRetryImmediately + + # Increment counter to signal an additional waiting worker. + state["num_workers_waiting"] += 1 + + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return active_version + + except etcd.EtcdCompareFailed: + logger.info("Announce self as waiting CAS unsuccessful, retrying") + + def wait_for_rendezvous_to_free(self, expected_version): + """ + When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. + + Such opportunity may come from: + + 1. rendezvous state changed by someone else, in which case we unblock and retry. + 2. rendezvous becomes invalid because at least one member failed to renew their + leased keep_alive node. We detect this, and destroy the rendezvous. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] != "final" or state["version"] != expected_version: + return + + # Check if current rendezvous state is valid, in the sense that all + # its members are alive (renewing their lease). + # If not, try destroy this rendezvous, so a new one can be created. + alive_members = self.client.get( + self.get_path(f"/rdzv/v_{expected_version}") + ) + keep_alive_keys = [ch.key for ch in alive_members.children] + + for key in state["keep_alives"]: + if key not in keep_alive_keys: + # This participant didn't renew their lease. We'll declare this + # rendezvous version as dead (but only if it hadn't changed) + logger.info("Keep-alive key %s is not renewed.", key) + logger.info( + "Rendezvous version %s is incomplete. ", expected_version + ) + logger.info("Attempting to destroy it.") + + # Compare-and-delete operation. Throws if compare failed, + # which means rendezvous was already destroyed/re-created/closed, + # and we can try to re-enter the barrier. + self.client.delete( + key=self.get_path("/rdzv/active_version"), + prevValue=active_version.value, + ) + + logger.info( + "Destroyed rendezvous version %s successfully.", + expected_version, + ) + + # We can return (and retry) immediately + return + + # Existing rendezvous seems valid, no reason to destroy it. + # We just have to wait until something changes and re-check. + try: + overall_timeout = ( + max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + ) + self.client.watch( + key=self.get_path("/rdzv"), + index=active_version.etcd_index + 1, + recursive=True, + timeout=overall_timeout, + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + active_version, state = self.get_rdzv_state() + + def handle_join_last_call(self, expected_version, deadline): + """ + After we reach min number of workers, one particular worker takes on the + responsibility of waiting an additional timeout before closing the join window. + If the worker responsible for this fails, the rendezvous will be destroyed due + to expiring TTL, and the other participants will re-rendezvous. + + Here we expect to see state + Exit gracefully if either: + + 1. state becomes + 2. timeout happens (reaching deadline), in which case + we try the transition to + + Exit with exception otherwise. + """ + active_version, state = self.get_rdzv_state() + while True: + if state["status"] == "frozen" and state["version"] == expected_version: + # Worker set became frozen before last-call timeout. This is possible + # when num_max_workers is reached before the timeout. + return + + if state["status"] != "joinable" or state["version"] != expected_version: + raise EtcdRendezvousRetryableFailure( + "Rendezvous state transition no longer possible. Must re-enter." + ) + + # If timeout occurred, attempt a state transition (joinable -> frozen) + if time.time() >= deadline: + state["status"] = "frozen" + state["keep_alives"] = [] + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ttl=CONST_ETCD_FROZEN_TTL, + ) + # We successfully made this rendezvous frozen. + return + except etcd.EtcdCompareFailed: + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) + cas_delay() + active_version, state = self.get_rdzv_state() + continue + + # Timeout did not occur, so we must refresh TTL, and wait for + # further changes. Note: we only want TTL to be refreshed if + # state is still joinable, hence we use CAS for that here, + # even though we don't change any of the data. + try: + active_version = self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=active_version.value, + prev_value=active_version.value, + ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, + ) + + # Minimize "oversleeping": + timeout = min( + CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, + deadline - time.time() + 1.0, # Oversleeping by 1s is ok. + ) + active_version, state = self.try_wait_for_state_change( + etcd_index=active_version.etcd_index + 1, timeout=timeout + ) + except etcd.EtcdCompareFailed: + logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") + cas_delay() + active_version, state = self.get_rdzv_state() + + def set_closed(self): + """ + Mark rendezvous 'closed' for current run_id, which is used to signal other + participants to not attempt to perform (re-)rendezvous. This is useful + when one of the workers decides the job is complete. + """ + while True: + active_version, state = self.get_rdzv_state() + + if state["status"] == "closed": + # Already closed by someone else. + return + + state["status"] = "closed" + try: + self.client.test_and_set( + key=self.get_path("/rdzv/active_version"), + value=json.dumps(state), + prev_value=active_version.value, + ) + return + + except etcd.EtcdCompareFailed: + logger.info("Set closed CAS unsuccessful, retrying") + cas_delay() + + def get_rdzv_state(self): + active_version = self.client.get(key=self.get_path("/rdzv/active_version")) + return active_version, json.loads(active_version.value) + + def try_wait_for_state_change(self, etcd_index, timeout=None): + # Don't sleep past the overall deadline (at least more than by 1s) + overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 + timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) + + try: + self.client.watch( + self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout + ) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + if time.time() > self._rendezvous_deadline: + raise RendezvousTimeoutError + + # Unfortunately, we have to do another fetch in order to get last etcd_index. + return self.get_rdzv_state() + + def get_path(self, path): + if not path.startswith("/"): + path = "/" + path + + return f"{self._prefix}run_{self._run_id}{path}" + + def create_path_if_not_exists(self, full_path, ttl=None): + try: + self.client.write( + key=full_path, value=None, dir=True, prevExist=False, ttl=ttl + ) + except etcd.EtcdAlreadyExist: + pass + + def setup_lease_renewal(self, full_path, ttl): + # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, + # make sure you don't call any long-blocking methods that do not + # release the Python's GIL! An example of this is calling a pybind11 + # extension function that is blocking / long-running, but is not + # doing a scoped release of the GIL. + def lease_worker(client, path, ttl, stop_event): + while True: + try: + client.refresh(path, ttl=ttl) + except etcd.EtcdKeyNotFound: + break + except ConnectionRefusedError: + # This error usually occurs during test when the server already got terminated but the + # python garbage collector have not yet invoked the __del__ method. + break + + if stop_event.wait(timeout=ttl / 2): + break + + lease_stop_event = threading.Event() + lease_thread = threading.Thread( + target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) + ) + + lease_thread.daemon = True + lease_thread.start() + + return lease_stop_event + + def store_extra_data(self, rdzv_version, key, value): + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + try: + # If first time we are storing anything: + extra_data = self.client.write( + key=node, value=json.dumps({key: value}), prevExist=False + ) + return + except etcd.EtcdAlreadyExist: + pass + + # CAS loop, to make sure we don't lose concurrent stores. + while True: + # We never delete extra_data. Failure here should be fatal, no special handling. + extra_data = self.client.get(node) + + new_extra_data_value = json.loads(extra_data.value) + new_extra_data_value[key] = value + + try: + extra_data = self.client.test_and_set( + key=node, + value=json.dumps(new_extra_data_value), + prev_value=extra_data.value, + ) + return + except etcd.EtcdCompareFailed: + logger.info("Store extra_data CAS unsuccessful, retrying") + time.sleep(0.1) + + def load_extra_data(self, rdzv_version, key, timeout=None): + # 'extra_data' node itself, and the directory it is located in: + node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") + node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") + + # TODO: implement timeout + # https://github.com/pytorch/elastic/issues/12 + while True: + # Combined wait for the node itself, and the key inside it. + root = self.client.get(node_dir) + + # Find the extra_data node, if it exists + extra_data = [n for n in root.children if n.key == node] + assert len(extra_data) <= 1 + + # Node for extra_data exists, check the desired key inside it. + if len(extra_data) == 1: + extra_data_dict = json.loads(extra_data[0].value) + if key in extra_data_dict: + return extra_data_dict[key] + + # The 'extra_data' node doesn't exist, or they key isn't published yet. + # Wait for interesting events on the extra_data node and retry. + try: + self.client.watch(node, index=root.etcd_index + 1) + except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): + pass + + def setup_kv_store(self, rdzv_version): + store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") + self.create_path_if_not_exists(store_path) + return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" + hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) + + # The communication protocol + protocol = params.config.get("protocol") + if protocol is None: + protocol = "http" + else: + if protocol != "http" and protocol != "https": + raise ValueError("The etcd protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.config.get("cert") + if ssl_cert is not None: + cert_key = params.config.get("key") + if cert_key is not None: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, cert_key) + + # The root certificate + ca_cert = params.config.get("cacert") + + return etcd.Client( + hostname, + port, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + + +# Handler for torch.distributed "static" registration +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Usage: + + :: + + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8, + timeout=300, + last_call_timeout=30, + etcd_prefix="custom_prefix", + protocol="https", + cacert="/etc/kubernetes/certs/ca.crt", + cert="/etc/kubernetes/certs/client.crt", + key="/etc/kubernetes/certs/client.key") + # -- or -- + rdzv_params = RendezvousParameters( + backend="etcd", + endpoint="192.168.0.42:2379", + run_id="123", + min_nodes=4, + max_nodes=8) + + etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params) + + + Where: + run_id - unique id for this training job instance, + min_nodes - min number of workers expected to join the rendezvous, + max_nodes - max number of workers allowed to join the rendezvous, + defaults to min_workers is not specified. + timeout - total timeout within which next_rendezvous is expected to + succeed; a RendezvousTimeoutError is raised otherwise; + Defaults is 600 (10 minutes). + last_call_timeout - additional wait amount ("last call") after + min number of workers has been reached. + Defaults to 30 seconds. + etcd_prefix - path prefix (from etcd root), inside which all + etcd nodes will be created. + Default is "/torchelastic/p2p". + protocol - http (default) or https to access etcd. + cacert - CA cert to access etcd, only makes sense with https. + cert - client cert to access etcd, only makes sense with https. + key - client key to access etcd, only makes sense with https. + """ + client = _create_etcd_client(params) + + etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") + + rdzv = EtcdRendezvous( + client=client, + prefix=etcd_prefix, + run_id=params.run_id, + num_min_workers=params.min_nodes, + num_max_workers=params.max_nodes, + timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), + ) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..4cda28221ff4ec79fbd468a5067c91942b9b7be4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -0,0 +1,214 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import binascii +from base64 import b64decode, b64encode +from typing import cast + +import urllib3.exceptions # type: ignore[import] + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + +from torch.distributed import Store + +from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError +from .dynamic_rendezvous import RendezvousBackend, Token +from .etcd_store import EtcdStore +from .utils import parse_rendezvous_endpoint + + +class EtcdRendezvousBackend(RendezvousBackend): + """Represents an etcd-based rendezvous backend. + + Args: + client: + The ``etcd.Client`` instance to use to communicate with etcd. + run_id: + The run id of the rendezvous. + key_prefix: + The path under which to store the rendezvous state in etcd. + ttl: + The TTL of the rendezvous state. If not specified, defaults to two hours. + """ + + _DEFAULT_TTL = 7200 # 2 hours + + _client: etcd.Client + _key: str + _ttl: int + + def __init__( + self, + client: etcd.Client, + run_id: str, + key_prefix: str | None = None, + ttl: int | None = None, + ) -> None: + if not run_id: + raise ValueError("The run id must be a non-empty string.") + + self._client = client + + if key_prefix: + self._key = key_prefix + "/" + run_id + else: + self._key = run_id + + if ttl and ttl > 0: + self._ttl = ttl + else: + self._ttl = self._DEFAULT_TTL + + @property + def name(self) -> str: + """See base class.""" + return "etcd-v2" + + def get_state(self) -> tuple[bytes, Token] | None: + """See base class.""" + try: + result = self._client.read(self._key) + except etcd.EtcdKeyNotFound: + return None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + return self._decode_state(result) + + def set_state( + self, state: bytes, token: Token | None = None + ) -> tuple[bytes, Token, bool] | None: + """See base class.""" + base64_state = b64encode(state).decode() + + kwargs = {} + + def get_state(): + result = self.get_state() + if result is not None: + return *result, False + return None + + if token: + try: + token = int(token) + except ValueError: + return get_state() + + if token: + kwargs["prevIndex"] = token + else: + kwargs["prevExist"] = False + + try: + result = self._client.write(self._key, base64_state, self._ttl, **kwargs) + except (etcd.EtcdAlreadyExist, etcd.EtcdCompareFailed): + result = None + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + if result is None: + return get_state() + + tmp = *self._decode_state(result), True + return tmp + + def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]: + # pyrefly: ignore [missing-attribute] + base64_state = result.value.encode() + + try: + state = b64decode(base64_state) + except binascii.Error as exc: + raise RendezvousStateError( + "The state object is corrupt. See inner exception for details." + ) from exc + + # pyrefly: ignore [missing-attribute] + return state, result.modifiedIndex + + +def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379) + + # The timeout + read_timeout = cast(int, params.get_as_int("read_timeout", 60)) + if read_timeout <= 0: + raise ValueError("The read timeout must be a positive integer.") + + # The communication protocol + protocol = params.get("protocol", "http").strip().lower() + if protocol != "http" and protocol != "https": + raise ValueError("The protocol must be HTTP or HTTPS.") + + # The SSL client certificate + ssl_cert = params.get("ssl_cert") + if ssl_cert: + ssl_cert_key = params.get("ssl_cert_key") + if ssl_cert_key: + # The etcd client expects the certificate key as the second element + # of the `cert` tuple. + ssl_cert = (ssl_cert, ssl_cert_key) + + # The root certificate + ca_cert = params.get("ca_cert") + + try: + return etcd.Client( + host, + port, + read_timeout=read_timeout, + protocol=protocol, + cert=ssl_cert, + ca_cert=ca_cert, + allow_reconnect=True, + ) + except (etcd.EtcdException, urllib3.exceptions.TimeoutError) as exc: + raise RendezvousConnectionError( + "The connection to etcd has failed. See inner exception for details." + ) from exc + + +def create_backend(params: RendezvousParameters) -> tuple[EtcdRendezvousBackend, Store]: + """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters. + + +--------------+-----------------------------------------------------------+ + | Parameter | Description | + +==============+===========================================================+ + | read_timeout | The read timeout, in seconds, for etcd operations. | + | | Defaults to 60 seconds. | + +--------------+-----------------------------------------------------------+ + | protocol | The protocol to use to communicate with etcd. Valid | + | | values are "http" and "https". Defaults to "http". | + +--------------+-----------------------------------------------------------+ + | ssl_cert | The path to the SSL client certificate to use along with | + | | HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ssl_cert_key | The path to the private key of the SSL client certificate | + | | to use along with HTTPS. Defaults to ``None``. | + +--------------+-----------------------------------------------------------+ + | ca_cert | The path to the rool SSL authority certificate. Defaults | + | | to ``None``. | + +--------------+-----------------------------------------------------------+ + """ + client = _create_etcd_client(params) + + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) + + store = EtcdStore(client, "/torch/elastic/store") + + return backend, store diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py new file mode 100644 index 0000000000000000000000000000000000000000..347e7339d9a46a78c9edf20917eef6146672ffc8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import atexit +import logging +import os +import shlex +import shutil +import socket +import subprocess +import tempfile +import time +from typing import TextIO + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + pass + + +logger = logging.getLogger(__name__) + + +def find_free_port(): + """ + Find a free port and binds a temporary socket to it so that the port can be "reserved" until used. + + .. note:: the returned socket must be closed before using the port, + otherwise a ``address already in use`` error will happen. + The socket should be held and closed as close to the + consumer of the port as possible since otherwise, there + is a greater chance of race-condition where a different + process may see the port as being free and take it. + + Returns: a socket binded to the reserved free port + + Usage:: + + sock = find_free_port() + port = sock.getsockname()[1] + sock.close() + use_port(port) + """ + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + + for addr in addrs: + family, type, proto, _, _ = addr + try: + s = socket.socket(family, type, proto) + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() # type: ignore[possibly-undefined] + print(f"Socket creation attempt failed: {e}") + raise RuntimeError("Failed to create a socket") + + +def stop_etcd(subprocess, data_dir: str | None = None): + if subprocess and subprocess.poll() is None: + logger.info("stopping etcd server") + subprocess.terminate() + subprocess.wait() + + if data_dir: + logger.info("deleting etcd data dir: %s", data_dir) + shutil.rmtree(data_dir, ignore_errors=True) + + +class EtcdServer: + """ + .. note:: tested on etcd server v3.4.3. + + Starts and stops a local standalone etcd server on a random free + port. Useful for single node, multi-worker launches or testing, + where a sidecar etcd server is more convenient than having to + separately setup an etcd server. + + This class registers a termination handler to shutdown the etcd + subprocess on exit. This termination handler is NOT a substitute for + calling the ``stop()`` method. + + The following fallback mechanism is used to find the etcd binary: + + 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH + 2. Uses ``/bin/etcd`` if one exists + 3. Uses ``etcd`` from ``PATH`` + + Usage + :: + + server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd") + server.start() + client = server.get_client() + # use client + server.stop() + + Args: + etcd_binary_path: path of etcd server binary (see above for fallback path) + """ + + def __init__(self, data_dir: str | None = None): + self._port = -1 + self._host = "localhost" + + root = os.path.dirname(__file__) + default_etcd_bin = os.path.join(root, "bin/etcd") + self._etcd_binary_path = os.environ.get( + "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin + ) + if not os.path.isfile(self._etcd_binary_path): + self._etcd_binary_path = "etcd" + + self._base_data_dir = ( + data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") + ) + self._etcd_cmd = None + self._etcd_proc: subprocess.Popen | None = None + + def _get_etcd_server_process(self) -> subprocess.Popen: + if not self._etcd_proc: + raise RuntimeError( + "No etcd server process started. Call etcd_server.start() first" + ) + else: + return self._etcd_proc + + def get_port(self) -> int: + """Return the port the server is running on.""" + return self._port + + def get_host(self) -> str: + """Return the host the server is running on.""" + return self._host + + def get_endpoint(self) -> str: + """Return the etcd server endpoint (host:port).""" + return f"{self._host}:{self._port}" + + def start( + self, + timeout: int = 60, + num_retries: int = 3, + stderr: int | TextIO | None = None, + ) -> None: + """ + Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests. + + Args: + timeout: time (in seconds) to wait for the server to be ready + before giving up. + num_retries: number of retries to start the server. Each retry + will wait for max ``timeout`` before considering it as failed. + stderr: the standard error file handle. Valid values are + `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file + descriptor (a positive integer), an existing file object, and + `None`. + + Raises: + TimeoutError: if the server is not ready within the specified timeout + """ + curr_retries = 0 + while True: + try: + data_dir = os.path.join(self._base_data_dir, str(curr_retries)) + os.makedirs(data_dir, exist_ok=True) + return self._start(data_dir, timeout, stderr) + except Exception as e: + curr_retries += 1 + stop_etcd(self._etcd_proc) + logger.warning( # noqa: G200 + "Failed to start etcd server, got error: %s, retrying", str(e) + ) + if curr_retries >= num_retries: + shutil.rmtree(self._base_data_dir, ignore_errors=True) + raise + atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) + + def _start( + self, data_dir: str, timeout: int = 60, stderr: int | TextIO | None = None + ) -> None: + sock = find_free_port() + sock_peer = find_free_port() + self._port = sock.getsockname()[1] + peer_port = sock_peer.getsockname()[1] + + etcd_cmd = shlex.split( + " ".join( + [ + self._etcd_binary_path, + "--enable-v2", + "--data-dir", + data_dir, + "--listen-client-urls", + f"http://{self._host}:{self._port}", + "--advertise-client-urls", + f"http://{self._host}:{self._port}", + "--listen-peer-urls", + f"http://{self._host}:{peer_port}", + ] + ) + ) + + logger.info("Starting etcd server: [%s]", etcd_cmd) + + sock.close() + sock_peer.close() + self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr) + self._wait_for_ready(timeout) + + def get_client(self): + """Return an etcd client object that can be used to make requests to this server.""" + return etcd.Client( + host=self._host, port=self._port, version_prefix="/v2", read_timeout=10 + ) + + def _wait_for_ready(self, timeout: int = 60) -> None: + client = etcd.Client( + host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5 + ) + max_time = time.time() + timeout + + while time.time() < max_time: + if self._get_etcd_server_process().poll() is not None: + # etcd server process finished + exitcode = self._get_etcd_server_process().returncode + raise RuntimeError( + f"Etcd server process exited with the code: {exitcode}" + ) + try: + logger.info("etcd server ready. version: %s", client.version) + return + except Exception: + time.sleep(1) + raise TimeoutError("Timed out waiting for etcd server to be ready!") + + def stop(self) -> None: + """Stop the server and cleans up auto generated resources (e.g. data dir).""" + logger.info("EtcdServer stop method called") + stop_etcd(self._etcd_proc, self._base_data_dir) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py new file mode 100644 index 0000000000000000000000000000000000000000..faaf77587bc9d66e42110f8b36c8c17e5aedec87 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import random +import time +from base64 import b64decode, b64encode + +# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`. +from torch.distributed import Store + + +try: + import etcd # type: ignore[import] +except ModuleNotFoundError: + from . import _etcd_stub as etcd + + +# Delay (sleep) for a small random amount to reduce CAS failures. +# This does not affect correctness, but will reduce requests to etcd server. +def cas_delay(): + time.sleep(random.uniform(0, 0.1)) + + +# pyre-fixme[11]: Annotation `Store` is not defined as a type. +class EtcdStore(Store): + """ + Implement a c10 Store interface by piggybacking on the rendezvous etcd instance. + + This is the store object returned by ``EtcdRendezvous``. + """ + + def __init__( + self, + etcd_client, + etcd_store_prefix, + # Default timeout same as in c10d/Store.hpp + timeout: datetime.timedelta | None = None, + ): + super().__init__() # required for pybind trampoline. + + self.client = etcd_client + self.prefix = etcd_store_prefix + + if timeout is not None: + self.set_timeout(timeout) + + if not self.prefix.endswith("/"): + self.prefix += "/" + + def set(self, key, value): + """ + Write a key/value pair into ``EtcdStore``. + + Both key and value may be either Python ``str`` or ``bytes``. + """ + self.client.set(key=self.prefix + self._encode(key), value=self._encode(value)) + + def get(self, key) -> bytes: + """ + Get a value by key, possibly doing a blocking wait. + + If key is not immediately present, will do a blocking wait + for at most ``timeout`` duration or until the key is published. + + + Returns: + value ``(bytes)`` + + Raises: + LookupError - If key still not published after timeout + """ + b64_key = self.prefix + self._encode(key) + kvs = self._try_wait_get([b64_key]) + + if kvs is None: + raise LookupError(f"Key {key} not found in EtcdStore") + + return self._decode(kvs[b64_key]) + + def add(self, key, num: int) -> int: + """ + Atomically increment a value by an integer amount. + + The integer is represented as a string using base 10. If key is not present, + a default value of ``0`` will be assumed. + + Returns: + the new (incremented) value + + + """ + b64_key = self._encode(key) + # c10d Store assumes value is an integer represented as a decimal string + try: + # Assume default value "0", if this key didn't yet: + node = self.client.write( + key=self.prefix + b64_key, + value=self._encode(str(num)), # i.e. 0 + num + prevExist=False, + ) + return int(self._decode(node.value)) + except etcd.EtcdAlreadyExist: + pass + + while True: + # Note: c10d Store does not have a method to delete keys, so we + # can be sure it's still there. + node = self.client.get(key=self.prefix + b64_key) + new_value = self._encode(str(int(self._decode(node.value)) + num)) + try: + node = self.client.test_and_set( + key=node.key, value=new_value, prev_value=node.value + ) + return int(self._decode(node.value)) + except etcd.EtcdCompareFailed: + cas_delay() + + def wait(self, keys, override_timeout: datetime.timedelta | None = None): + """ + Wait until all of the keys are published, or until timeout. + + Raises: + LookupError - if timeout occurs + """ + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get(b64_keys, override_timeout) + if kvs is None: + raise LookupError("Timeout while waiting for keys in EtcdStore") + # No return value on success + + def check(self, keys) -> bool: + """Check if all of the keys are immediately present (without waiting).""" + b64_keys = [self.prefix + self._encode(key) for key in keys] + kvs = self._try_wait_get( + b64_keys, + override_timeout=datetime.timedelta(microseconds=1), # as if no wait + ) + return kvs is not None + + # + # Encode key/value data in base64, so we can store arbitrary binary data + # in EtcdStore. Input can be `str` or `bytes`. + # In case of `str`, utf-8 encoding is assumed. + # + def _encode(self, value) -> str: + if type(value) is bytes: + return b64encode(value).decode() + elif type(value) is str: + return b64encode(value.encode()).decode() + raise ValueError("Value must be of type str or bytes") + + # + # Decode a base64 string (of type `str` or `bytes`). + # Return type is `bytes`, which is more convenient with the Store interface. + # + def _decode(self, value) -> bytes: + if type(value) is bytes: + return b64decode(value) + elif type(value) is str: + return b64decode(value.encode()) + raise ValueError("Value must be of type str or bytes") + + # + # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys + # are published or timeout occurs. + # This is a helper method for the public interface methods. + # + # On success, a dictionary of {etcd key -> etcd value} is returned. + # On timeout, None is returned. + # + def _try_wait_get(self, b64_keys, override_timeout=None): + timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined] + deadline = time.time() + timeout.total_seconds() + + while True: + # Read whole directory (of keys), filter only the ones waited for + all_nodes = None + try: + all_nodes = self.client.get(key=self.prefix) + req_nodes = { + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys + } + + if len(req_nodes) == len(b64_keys): + # All keys are available + return req_nodes + except etcd.EtcdKeyNotFound: + pass + + watch_timeout = deadline - time.time() + if watch_timeout <= 0: + return None + + try: + index = all_nodes.etcd_index + 1 if all_nodes else 0 + self.client.watch( + key=self.prefix, + recursive=True, + timeout=watch_timeout, + index=index, + ) + except etcd.EtcdWatchTimedOut: + if time.time() >= deadline: + return None + else: + continue + except etcd.EtcdEventIndexCleared: + continue diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ebada4623a814c6b8a2b802d544e5926426e13fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/registry.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from importlib.metadata import entry_points + +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) +from .dynamic_rendezvous import create_handler + + +log = logging.getLogger(__name__) + +__all__ = ["get_rendezvous_handler"] + + +def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import static_tcp_rendezvous + + return static_tcp_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler: + from . import etcd_rendezvous + + return etcd_rendezvous.create_rdzv_handler(params) + + +def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler: + from .etcd_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler: + from .c10d_rendezvous_backend import create_backend + + backend, store = create_backend(params) + + return create_handler(store, backend, params) + + +def _register_default_handlers() -> None: + handler_registry.register("etcd", _create_etcd_handler) + handler_registry.register("etcd-v2", _create_etcd_v2_handler) + handler_registry.register("c10d", _create_c10d_handler) + handler_registry.register("static", _create_static_handler) + + +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + +def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: + """ + Obtain a reference to a :py:class`RendezvousHandler`. + + Custom rendezvous handlers can be registered by + + :: + + from torch.distributed.elastic.rendezvous import rendezvous_handler_registry + from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler + + + def create_my_rdzv(params: RendezvousParameters): + return MyCustomRdzv(params) + + + rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) + + my_rdzv_handler = get_rendezvous_handler( + "my_rdzv_backend_name", RendezvousParameters + ) + """ + return handler_registry.create_handler(params) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py new file mode 100644 index 0000000000000000000000000000000000000000..52b68000530889b6be1a8ec78ea762f6e5817975 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import logging +from typing import cast + +from torch.distributed import PrefixStore, Store, TCPStore +from torch.distributed.elastic.rendezvous import ( + RendezvousHandler, + RendezvousInfo, + RendezvousParameters, + RendezvousStoreInfo, +) +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + + +__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] + +logger = logging.getLogger(__name__) + +_default_timeout_seconds = 600 + + +class StaticTCPRendezvous(RendezvousHandler): + """ + Static rendezvous that is a wrapper around the TCPStore. + + Creates TCPStore based on the input parameters with the + listener on the agent with group_rank=0 + """ + + def __init__( + self, + master_addr: str, + master_port: int, + rank: int, + world_size: int, + run_id: str, + timeout: int, + ): + self.master_addr = master_addr + self.master_port = master_port + self.rank = rank + self.world_size = world_size + self.run_id = run_id + self.timeout = datetime.timedelta(seconds=timeout) + self._store: Store | None = None + + def get_backend(self) -> str: + return "static" + + @property + def use_agent_store(self) -> bool: + return True + + def next_rendezvous(self) -> RendezvousInfo: + logger.info("Creating TCPStore as the c10d::Store implementation") + is_master = self.rank == 0 + if not self._store: + self._store = TCPStore( # type: ignore[call-arg] + self.master_addr, + self.master_port, + self.world_size, + is_master, + self.timeout, + multi_tenant=True, + ) + store = PrefixStore(self.run_id, self._store) + # TCPStore server instance is used by trainer code + bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) + return RendezvousInfo( + store, + self.rank, + self.world_size, + bootstrap_store_info, + ) + + def is_closed(self): + return False + + def set_closed(self): + pass + + def num_nodes_waiting(self): + return 0 + + def get_run_id(self) -> str: + return self.run_id + + def shutdown(self) -> bool: + return True + + +def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: + if "rank" not in params.config: + raise ValueError( + "rank is absent in RendezvousParameters." + "Try add --node-rank to the cmd request" + ) + endpoint = params.endpoint.strip() + if not endpoint: + raise ValueError( + "endpoint is absent in RendezvousParameters" + "Try add --master-port and --master-addr to the cmd request" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) + if master_port == -1: + raise ValueError( + f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" + ) + world_size = params.max_nodes + rank = cast(int, params.config.get("rank")) + run_id = params.run_id + if "timeout" in params.config: + timeout = int(params.config["timeout"]) + else: + timeout = _default_timeout_seconds + + return StaticTCPRendezvous( + master_addr, master_port, rank, world_size, run_id, timeout + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05ebbba55913fc4f7d9843420a68b4ae233f3e14 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/rendezvous/utils.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ipaddress +import random +import re +import socket +import time +import weakref +from collections.abc import Callable +from datetime import timedelta +from threading import Event, Thread +from typing import Any + + +__all__ = ["parse_rendezvous_endpoint"] + + +def _parse_rendezvous_config(config_str: str) -> dict[str, str]: + """Extract key-value pairs from a rendezvous configuration string. + + Args: + config_str: + A string in format =,...,=. + """ + config: dict[str, str] = {} + + config_str = config_str.strip() + if not config_str: + return config + + key_values = config_str.split(",") + for kv in key_values: + key, *values = kv.split("=", 1) + + key = key.strip() + if not key: + raise ValueError( + "The rendezvous configuration string must be in format " + "=,...,=." + ) + + value: str | None + if values: + value = values[0].strip() + else: + value = None + if not value: + raise ValueError( + f"The rendezvous configuration option '{key}' must have a value specified." + ) + + config[key] = value + return config + + +def _try_parse_port(port_str: str) -> int | None: + """Try to extract the port number from ``port_str``.""" + if port_str and re.match(r"^[0-9]{1,5}$", port_str): + return int(port_str) + return None + + +def parse_rendezvous_endpoint( + endpoint: str | None, default_port: int +) -> tuple[str, int]: + """Extract the hostname and the port number from a rendezvous endpoint. + + Args: + endpoint: + A string in format [:]. + default_port: + The port number to use if the endpoint does not include one. + + Returns: + A tuple of hostname and port number. + """ + if endpoint is not None: + endpoint = endpoint.strip() + + if not endpoint: + return ("localhost", default_port) + + # An endpoint that starts and ends with brackets represents an IPv6 address. + if endpoint[0] == "[" and endpoint[-1] == "]": + host, *rest = endpoint, *[] + else: + host, *rest = endpoint.rsplit(":", 1) + + # Sanitize the IPv6 address. + if len(host) > 1 and host[0] == "[" and host[-1] == "]": + host = host[1:-1] + + if len(rest) == 1: + port = _try_parse_port(rest[0]) + if port is None or port >= 2**16: + raise ValueError( + f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " + "between 0 and 65536." + ) + else: + port = default_port + + if not re.match(r"^[\w\.:-]+$", host): + raise ValueError( + f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of " + "labels, an IPv4 address, or an IPv6 address." + ) + + return host, port + + +def _matches_machine_hostname(host: str) -> bool: + """Indicate whether ``host`` matches the hostname of this machine. + + This function compares ``host`` to the hostname as well as to the IP + addresses of this machine. Note that it may return a false negative if this + machine has CNAME records beyond its FQDN or IP addresses assigned to + secondary NICs. + """ + if host == "localhost": + return True + + try: + addr = ipaddress.ip_address(host) + except ValueError: + addr = None + + if addr and addr.is_loopback: + return True + + try: + host_addr_list = socket.getaddrinfo( + host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + except (ValueError, socket.gaierror) as _: + host_addr_list = [] + + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] + + this_host = socket.gethostname() + if host == this_host: + return True + + addr_list = socket.getaddrinfo( + this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME + ) + for addr_info in addr_list: + # If we have an FQDN in the addr_info, compare it to `host`. + if addr_info[3] and addr_info[3] == host: + return True + + # Otherwise if `host` represents an IP address, compare it to our IP + # address. + if addr and addr_info[4][0] == str(addr): + return True + + # If the IP address matches one of the provided host's IP addresses + if addr_info[4][0] in host_ip_list: + return True + + return False + + +def _delay(seconds: float | tuple[float, float]) -> None: + """Suspend the current thread for ``seconds``. + + Args: + seconds: + Either the delay, in seconds, or a tuple of a lower and an upper + bound within which a random delay will be picked. + """ + if isinstance(seconds, tuple): + seconds = random.uniform(*seconds) + # Ignore delay requests that are less than 10 milliseconds. + if seconds >= 0.01: + time.sleep(seconds) + + +class _PeriodicTimer: + """Represent a timer that periodically runs a specified function. + + Args: + interval: + The interval, in seconds, between each run. + function: + The function to run. + """ + + # The state of the timer is hold in a separate context object to avoid a + # reference cycle between the timer and the background thread. + class _Context: + interval: float + function: Callable[..., None] + args: tuple[Any, ...] + kwargs: dict[str, Any] + stop_event: Event + + _name: str | None + _thread: Thread | None + _finalizer: weakref.finalize | None + + # The context that is shared between the timer and the background thread. + _ctx: _Context + + def __init__( + self, + interval: timedelta, + function: Callable[..., None], + *args: Any, + **kwargs: Any, + ) -> None: + self._name = None + + self._ctx = self._Context() + self._ctx.interval = interval.total_seconds() + self._ctx.function = function # type: ignore[assignment] + self._ctx.args = args or () + self._ctx.kwargs = kwargs or {} + self._ctx.stop_event = Event() + + self._thread = None + self._finalizer = None + + @property + def name(self) -> str | None: + """Get the name of the timer.""" + return self._name + + def set_name(self, name: str) -> None: + """Set the name of the timer. + + The specified name will be assigned to the background thread and serves + for debugging and troubleshooting purposes. + """ + if self._thread: + raise RuntimeError("The timer has already started.") + + self._name = name + + def start(self) -> None: + """Start the timer.""" + if self._thread: + raise RuntimeError("The timer has already started.") + + self._thread = Thread( + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, + ) + + # We avoid using a regular finalizer (a.k.a. __del__) for stopping the + # timer as joining a daemon thread during the interpreter shutdown can + # cause deadlocks. The weakref.finalize is a superior alternative that + # provides a consistent behavior regardless of the GC implementation. + self._finalizer = weakref.finalize( + self, self._stop_thread, self._thread, self._ctx.stop_event + ) + + # We do not attempt to stop our background thread during the interpreter + # shutdown. At that point we do not even know whether it still exists. + self._finalizer.atexit = False + + self._thread.start() + + def cancel(self) -> None: + """Stop the timer at the next opportunity.""" + if self._finalizer: + self._finalizer() + + @staticmethod + def _run(ctx) -> None: + while not ctx.stop_event.wait(ctx.interval): + ctx.function(*ctx.args, **ctx.kwargs) + + @staticmethod + def _stop_thread(thread, stop_event): + stop_event.set() + + thread.join() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c2ea349cc67ff7175d5ef17ec63aecddbf52a7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Expiration timers are set up on the same process as the agent and +used from your script to deal with stuck workers. When you go into +a code-block that has the potential to get stuck you can acquire +an expiration timer, which instructs the timer server to kill the +process if it does not release the timer by the self-imposed expiration +deadline. + +Usage:: + + import torchelastic.timer as timer + import torchelastic.agent.server as agent + + def main(): + start_method = "spawn" + message_queue = mp.get_context(start_method).Queue() + server = timer.LocalTimerServer(message, max_interval=0.01) + server.start() # non-blocking + + spec = WorkerSpec( + fn=trainer_func, + args=(message_queue,), + ...) + agent = agent.LocalElasticAgent(spec, start_method) + agent.run() + + def trainer_func(message_queue): + timer.configure(timer.LocalTimerClient(message_queue)) + with timer.expires(after=60): # 60 second expiry + # do some work + +In the example above if ``trainer_func`` takes more than 60 seconds to +complete, then the worker process is killed and the agent retries the worker group. +""" + +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) +from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..762535aa9eb8a62b0118447a00671a9e533af5d0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25d157e688c51ed727eec8f2e81b269866910826 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..060cb60092d8d670f70c719fe7ae79b1f607fb67 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1ef347cd3045c02d0b592180fc7a0065b6ef19f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c9e7950e4e565906dc1be15c3c2c6049a3b973 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py new file mode 100644 index 0000000000000000000000000000000000000000..efe942022246e90c3b6b68fae59be012d9c8d56b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/api.py @@ -0,0 +1,281 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import abc +import logging +import threading +import time +from contextlib import contextmanager +from inspect import getframeinfo, stack +from typing import Any + + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] + +logger = logging.getLogger(__name__) + + +class TimerRequest: + """ + Data object representing a countdown timer acquisition and release + that is used between the ``TimerClient`` and ``TimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + + .. note:: the type of ``worker_id`` is implementation specific. + It is whatever the TimerServer and TimerClient implementations + have on to uniquely identify a worker. + """ + + __slots__ = ["worker_id", "scope_id", "expiration_time"] + + def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): + self.worker_id = worker_id + self.scope_id = scope_id + self.expiration_time = expiration_time + + def __eq__(self, other): + if isinstance(other, TimerRequest): + return ( + self.worker_id == other.worker_id + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + ) + return False + + +class TimerClient(abc.ABC): + """ + Client library to acquire and release countdown timers by communicating + with the TimerServer. + """ + + @abc.abstractmethod + def acquire(self, scope_id: str, expiration_time: float) -> None: + """ + Acquires a timer for the worker that holds this client object + given the scope_id and expiration_time. Typically registers + the timer with the TimerServer. + """ + + @abc.abstractmethod + def release(self, scope_id: str): + """ + Releases the timer for the ``scope_id`` on the worker this + client represents. After this method is + called, the countdown timer on the scope is no longer in effect. + """ + + +class RequestQueue(abc.ABC): + """ + Consumer queue holding timer acquisition/release requests + """ + + @abc.abstractmethod + def size(self) -> int: + """ + Returns the size of the queue at the time this method is called. + Note that by the time ``get`` is called the size of the queue + may have increased. The size of the queue should not decrease + until the ``get`` method is called. That is, the following assertion + should hold: + + size = q.size() + res = q.get(size, timeout=0) + assert size == len(res) + + -- or -- + + size = q.size() + res = q.get(size * 2, timeout=1) + assert size <= len(res) <= size * 2 + """ + + @abc.abstractmethod + def get(self, size: int, timeout: float) -> list[TimerRequest]: + """ + Gets up to ``size`` number of timer requests in a blocking fashion + (no more than ``timeout`` seconds). + """ + + +class TimerServer(abc.ABC): + """ + Entity that monitors active timers and expires them + in a timely fashion. This server is responsible for + reaping workers that have expired timers. + """ + + def __init__( + self, request_queue: RequestQueue, max_interval: float, daemon: bool = True + ): + """ + :param request_queue: Consumer ``RequestQueue`` + :param max_interval: max time (in seconds) to wait + for an item in the request_queue + :param daemon: whether to run the watchdog thread as a daemon + """ + super().__init__() + self._request_queue = request_queue + self._max_interval = max_interval + self._daemon = daemon + self._watchdog_thread: threading.Thread | None = None + self._stop_signaled = False + + @abc.abstractmethod + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + """ + Processes the incoming timer requests and registers them with the server. + The timer request can either be a acquire-timer or release-timer request. + Timer requests with a negative expiration_time should be interpreted + as a release-timer request. + """ + + @abc.abstractmethod + def clear_timers(self, worker_ids: set[Any]) -> None: + """ + Clears all timers for the given ``worker_ids``. + """ + + @abc.abstractmethod + def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]: + """ + Returns all expired timers for each worker_id. An expired timer + is a timer for which the expiration_time is less than or equal to + the provided deadline. + """ + + @abc.abstractmethod + def _reap_worker(self, worker_id: Any) -> bool: + """ + Reaps the given worker. Returns True if the worker has been + successfully reaped, False otherwise. If any uncaught exception + is thrown from this method, the worker is considered reaped + and all associated timers will be removed. + """ + + def _reap_worker_no_throw(self, worker_id: Any) -> bool: + """ + Wraps ``_reap_worker(worker_id)``, if an uncaught exception is + thrown, then it considers the worker as reaped. + """ + try: + return self._reap_worker(worker_id) + except Exception: + logger.exception( + "Uncaught exception thrown from _reap_worker(), " + "check that the implementation correctly catches exceptions", + ) + return True + + def _watchdog_loop(self): + while not self._stop_signaled: + try: + self._run_watchdog() + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self): + batch_size = max(1, self._request_queue.size()) + timer_requests = self._request_queue.get(batch_size, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_ids = set() + for worker_id, expired_timers in self.get_expired_timers(now).items(): + logger.info( + "Reaping worker_id=[%s]. Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), + ) + if self._reap_worker_no_throw(worker_id): + logger.info("Successfully reaped worker=[%s]", worker_id) + reaped_worker_ids.add(worker_id) + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id + ) + self.clear_timers(reaped_worker_ids) + + def _get_scopes(self, timer_requests): + return [r.scope_id for r in timer_requests] + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + + +_timer_client: TimerClient | None = None + + +def configure(timer_client: TimerClient): + """ + Configures a timer client. Must be called before using ``expires``. + """ + global _timer_client + _timer_client = timer_client + logger.info("Timer client configured to: %s", type(_timer_client).__name__) + + +@contextmanager +def expires(after: float, scope: str | None = None, client: TimerClient | None = None): + """ + Acquires a countdown timer that expires in ``after`` seconds from now, + unless the code-block that it wraps is finished within the timeframe. + When the timer expires, this worker is eligible to be reaped. The + exact meaning of "reaped" depends on the client implementation. In + most cases, reaping means to terminate the worker process. + Note that the worker is NOT guaranteed to be reaped at exactly + ``time.now() + after``, but rather the worker is "eligible" for being + reaped and the ``TimerServer`` that the client talks to will ultimately + make the decision when and how to reap the workers with expired timers. + + Usage:: + + torch.distributed.elastic.timer.configure(LocalTimerClient()) + with expires(after=10): + torch.distributed.all_reduce(...) + """ + if client is None: + if _timer_client is None: + raise RuntimeError("Configure timer client before using countdown timers.") + client = _timer_client + if scope is None: + # grab the caller file + lineno + caller = getframeinfo(stack()[1][0]) + scope = f"{caller.filename}#{caller.lineno}" + expiration = time.time() + after + client.acquire(scope, expiration) + try: + yield + finally: + client.release(scope) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e385d91283a7b610f00397bfa4bc4800a89761ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/debug_info_logging.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.elastic.utils.logging import get_logger + + +logger = get_logger(__name__) + +__all__ = ["log_debug_info_for_expired_timers"] + + +def log_debug_info_for_expired_timers( + run_id: str, + expired_timers: dict[int, list[str]], +): + if expired_timers: + logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..5855efefcc85342378c273657fed27b37160a6ba --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py @@ -0,0 +1,444 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import json +import os +import select +import signal +import sys +import threading +import time +from collections.abc import Callable +from typing import TypeVar +from typing_extensions import ParamSpec + +from torch.distributed.elastic.timer.api import TimerClient, TimerRequest +from torch.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) +from torch.distributed.elastic.utils.logging import get_logger + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] + +logger = get_logger(__name__) + + +def _retry(max_retries: int, sleep_time: float) -> Callable: + """ + A simple retry wrapper. + + Args: + max_retries: int, the maximum number of retries. + sleep_time: float, the time to sleep between retries. + """ + + def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]: + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + for i in range(max_retries): + try: + return func(*args, **kwargs) + except Exception: + logger.exception("Error running %s. Retrying...", func.__name__) + if i < max_retries - 1: + time.sleep(sleep_time) + else: + raise + + return wrapper + + return wrapper + + +class FileTimerRequest(TimerRequest): + """ + Data object representing a countdown timer acquisition and release + that is used between the ``FileTimerClient`` and ``FileTimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + ``signal`` is the signal to reap the worker process from the server + process. + """ + + __slots__ = ["version", "signal"] + + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: + super().__init__( + worker_id=worker_pid, scope_id=scope_id, expiration_time=expiration_time + ) + self.version = 1 + self.signal = signal + + @property + def worker_pid(self) -> int: + return self.worker_id + + def __eq__(self, other) -> bool: + if isinstance(other, FileTimerRequest): + return ( + super().__eq__(other) + and self.version == other.version + and self.signal == other.signal + ) + return False + + def to_json(self) -> str: + return json.dumps( + { + "version": self.version, + "pid": self.worker_pid, + "scope_id": self.scope_id, + "expiration_time": self.expiration_time, + "signal": self.signal, + }, + ) + + +class FileTimerClient(TimerClient): + """ + Client side of ``FileTimerServer``. This client is meant to be used + on the same host that the ``FileTimerServer`` is running on and uses + pid to uniquely identify a worker. + This client uses a named_pipe to send timer requests to the + ``FileTimerServer``. This client is a producer while the + ``FileTimerServer`` is a consumer. Multiple clients can work with + the same ``FileTimerServer``. + + Args: + + file_path: str, the path of a FIFO special file. ``FileTimerServer`` + must have created it by calling os.mkfifo(). + + signal: signal, the signal to use to kill the process. Using a + negative or zero signal will not kill the process. + """ + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: + super().__init__() + self._file_path = file_path + self.signal = signal + + @_retry(max_retries=10, sleep_time=0.1) + def _open_non_blocking(self) -> io.TextIOWrapper | None: + # The server may have crashed or may haven't started yet. + # In such case, calling open() in blocking model blocks the client. + # To avoid such issue, open it in non-blocking mode, and an OSError will + # be raised if the server is not there. + fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) + return os.fdopen(fd, "wt") + + def _send_request(self, request: FileTimerRequest) -> None: + try: + file = self._open_non_blocking() + except Exception as e: + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) from e + with file: + json_request = request.to_json() + # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. + if len(json_request) > select.PIPE_BUF: + raise RuntimeError( + f"FileTimerRequest larger than {select.PIPE_BUF} bytes " + f"is not supported: {json_request}" + ) + file.write(json_request + "\n") + + def acquire(self, scope_id: str, expiration_time: float) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), + scope_id=scope_id, + expiration_time=expiration_time, + signal=self.signal, + ), + ) + + def release(self, scope_id: str) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 + ), + ) + + +class FileTimerServer: + """ + Server that works with ``FileTimerClient``. Clients are expected to be + running on the same host as the process that is running this server. + Each host in the job is expected to start its own timer server locally + and each server instance manages timers for local workers (running on + processes on the same host). + + Args: + + file_path: str, the path of a FIFO special file to be created. + + max_interval: float, max interval in seconds for each watchdog loop. + + daemon: bool, running the watchdog thread in daemon mode or not. + A daemon thread will not block a process to stop. + log_event: Callable[[Dict[str, str]], None], an optional callback for + logging the events in JSON format. + """ + + def __init__( + self, + file_path: str, + run_id: str, + max_interval: float = 10, + daemon: bool = True, + log_event: Callable[[str, FileTimerRequest | None], None] | None = None, + ) -> None: + self._file_path = file_path + self._run_id = run_id + self._max_interval = max_interval + self._daemon = daemon + self._timers: dict[tuple[int, str], FileTimerRequest] = {} + self._stop_signaled = False + self._watchdog_thread: threading.Thread | None = None + + self._is_client_started = False + if os.path.exists(self._file_path): + os.remove(self._file_path) + os.mkfifo(self._file_path) + # For test only. Count the number of requests received. + self._request_count = 0 + # For test only. Process all requests and stop the server. + self._run_once = False + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) + self._last_progress_time = int(time.time()) + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", + type(self).__name__, + self._max_interval, + self._daemon, + self._file_path, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + self._log_event("watchdog started", None) + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + self._log_event("watchdog stopped", None) + + def run_once(self) -> None: + self._run_once = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join() + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + + @staticmethod + def is_process_running(pid: int): + """ + function to check process is running or not + """ + try: + # Check if the process exists and we can send signals to it + os.kill(pid, 0) + return True + except OSError: + return False + + def _watchdog_loop(self) -> None: + # Open the pipe in blocking mode blocks the server thread. + # This is fine for the following reasons: + # 1. No client case usually does not happen. + # 2. We are running the watchdog loop in a separate daemon + # thread, which will not block the process to stop. + try: + with open(self._file_path) as fd: + self._is_client_started = True + while not self._stop_signaled: + try: + run_once = self._run_once + self._run_watchdog(fd) + if run_once: + break + self._last_progress_time = int(time.time()) + except Exception: + logger.exception("Error running watchdog") + + except Exception: + logger.exception("Could not open the FileTimerServer pipe") + raise + + def _run_watchdog(self, fd: io.TextIOWrapper) -> None: + timer_requests = self._get_requests(fd, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_pids = set() + kill_process = False + reap_signal = 0 + + all_expired_timers = self.get_expired_timers(now) + log_debug_info_for_expired_timers( + self._run_id, + { + pid: [expired_timer.to_json() for expired_timer in expired_timers] + for pid, expired_timers in all_expired_timers.items() + }, + ) + + for worker_pid, expired_timers in all_expired_timers.items(): + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) + reaped_worker_pids.add(worker_pid) + # In case we have multiple expired timers, we find the first timer + # with a valid signal (>0) in the expiration time order. + expired_timers.sort(key=lambda timer: timer.expiration_time) + signal = 0 + expired_timer = None + for timer in expired_timers: + self._log_event("timer expired", timer) + if timer.signal > 0: + signal = timer.signal + expired_timer = timer + break + if signal <= 0: + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) + continue + if self._reap_worker(worker_pid, signal): + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) + self._log_event("kill worker process", expired_timer) + kill_process = True + reap_signal = signal + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) + if kill_process and reap_signal > 0: + logger.info( + "Terminating the server process=[%s] because of expired timers", + os.getpid(), + ) + self._reap_worker(os.getpid(), reap_signal) + + self.clear_timers(reaped_worker_pids) + + def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]: + return [r.scope_id for r in timer_requests] + + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> list[FileTimerRequest]: + start = time.time() + requests = [] + while not self._stop_signaled or self._run_once: + # For named pipe, readline() is blocking when at least one writer opens. + # It returns only when flush() is called at the writer side. + # Note that flush() is automatically called inside close(). + # After the last writer closes, readline() is not blocking. + # It will return an empty string when it's at end-of-file. + # Since the client side always opens the pipe, writes a message and closes + # the pipe immediately, the readline() call below is not blocking for long. + json_request = fd.readline() + if len(json_request) == 0: + if self._run_once: + break + time.sleep(min(max_interval, 1)) + else: + request = json.loads(json_request) + pid = request["pid"] + scope_id = request["scope_id"] + expiration_time = request["expiration_time"] + signal = request["signal"] + requests.append( + FileTimerRequest( + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, + ) + ) + now = time.time() + if now - start > max_interval: + break + return requests + + def register_timers(self, timer_requests: list[FileTimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_pid + scope_id = request.scope_id + expiration_time = request.expiration_time + self._request_count += 1 + + key = (pid, scope_id) + # negative expiration is a proxy for a release call + if expiration_time < 0: + if key in self._timers: + del self._timers[key] + else: + self._timers[key] = request + + def clear_timers(self, worker_pids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_pids or not FileTimerServer.is_process_running(pid): + del self._timers[(pid, scope_id)] + + def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[int, list[FileTimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_pid, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_pid: int, signal: int) -> bool: + try: + os.kill(worker_pid, signal) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_pid) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_pid) + return False + + def get_last_progress_time(self) -> int: + return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e66ef3fae34958422c1160bfdc1994b13bf1553 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/timer/local_timer.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging +import multiprocessing as mp +import os +import signal +import time +from queue import Empty +from typing import Any + +from .api import RequestQueue, TimerClient, TimerRequest, TimerServer + + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] + +logger = logging.getLogger(__name__) + + +class LocalTimerClient(TimerClient): + """ + Client side of ``LocalTimerServer``. This client is meant to be used + on the same host that the ``LocalTimerServer`` is running on and uses + pid to uniquely identify a worker. This is particularly useful in situations + where one spawns a subprocess (trainer) per GPU on a host with multiple + GPU devices. + """ + + def __init__(self, mp_queue): + super().__init__() + self._mp_queue = mp_queue + + def acquire(self, scope_id, expiration_time): + pid = os.getpid() + acquire_request = TimerRequest(pid, scope_id, expiration_time) + self._mp_queue.put(acquire_request) + + def release(self, scope_id): + pid = os.getpid() + release_request = TimerRequest(pid, scope_id, -1) + self._mp_queue.put(release_request) + + +class MultiprocessingRequestQueue(RequestQueue): + """ + A ``RequestQueue`` backed by python ``multiprocessing.Queue`` + """ + + def __init__(self, mp_queue: mp.Queue): + super().__init__() + self._mp_queue = mp_queue + + def size(self) -> int: + return self._mp_queue.qsize() + + def get(self, size, timeout: float) -> list[TimerRequest]: + requests = [] + wait = timeout + for _ in range(size): + start = time.time() + + try: + r = self._mp_queue.get(block=True, timeout=wait) + except Empty: + break + + requests.append(r) + wait = wait - (time.time() - start) + if wait <= 0: + break + + return requests + + +class LocalTimerServer(TimerServer): + """ + Server that works with ``LocalTimerClient``. Clients are expected to be + subprocesses to the parent process that is running this server. Each host + in the job is expected to start its own timer server locally and each + server instance manages timers for local workers (running on processes + on the same host). + """ + + def __init__( + self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True + ): + super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) + self._timers: dict[tuple[Any, str], TimerRequest] = {} + + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_id + scope_id = request.scope_id + expiration_time = request.expiration_time + + # negative expiration is a proxy for a release call + if expiration_time < 0: + self._timers.pop((pid, scope_id), None) + else: + self._timers[(pid, scope_id)] = request + + def clear_timers(self, worker_ids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_ids: + self._timers.pop((pid, scope_id)) + + def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[Any, list[TimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_id, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_id: int) -> bool: + try: + os.kill(worker_id, signal.SIGKILL) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_id) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_id) + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2bbf5bbe2348bb0eaa411a034710dd14f7648e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43b8b1c1b66c72cf11bdea93cc87c48fa0c38bac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83fbdd5baea0c31796e6cec6c9f3e520020e3ca3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a20e91e49df5aebbd5478611de3a397fea3608f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..101f033af849a3478632af3519292720e4564e35 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dd938c8e20afafe771df96c6fa7ccce491bbbde Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9754c84c85bed265223b83ec7f3e26b555464a51 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2b881137047c23789a061a719437a43b1743959f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/api.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import socket +from string import Template +from typing import Any + + +def get_env_variable_or_raise(env_name: str) -> str: + r""" + Tries to retrieve environment variable. Raises ``ValueError`` + if no environment variable found. + + Args: + env_name (str): Name of the env variable + """ + value = os.environ.get(env_name, None) + if value is None: + msg = f"Environment variable {env_name} expected, but not set" + raise ValueError(msg) + return value + + +def get_socket_with_port() -> socket.socket: + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError: + s.close() + raise RuntimeError("Failed to create a socket") + + +class macros: + """ + Defines simple macros for caffe2.distributed.launch cmd args substitution + """ + + local_rank = "${local_rank}" + + @staticmethod + def substitute(args: list[Any], local_rank: str) -> list[str]: + args_sub = [] + for arg in args: + if isinstance(arg, str): + sub = Template(arg).safe_substitute(local_rank=local_rank) + args_sub.append(sub) + else: + args_sub.append(arg) + return args_sub diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c39bca6f3c8a31f5f2d7115ad12c1fc4925fe1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__init__.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cycling_iterator import CyclingIterator # noqa: F401 +from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db6ff00919260b0692af4ee439d6d5556f2aecc7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f152e4f639f93d987f156cfedca634254f696117 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20a845592707a9446cdfa6092c5536b12e5627f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..291a04226db79c77b3bde4cec239e45b31be81b5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +from collections.abc import Callable, Iterator +from typing import TypeVar +from typing_extensions import Self + + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +_T = TypeVar("_T") + +__all__ = ["CyclingIterator"] + + +class CyclingIterator(Iterator[_T]): + """ + An iterator decorator that cycles through the + underlying iterator "n" times. Useful to "unroll" + the dataset across multiple training epochs. + + The generator function is called as ``generator_fn(epoch)`` + to obtain the underlying iterator, where ``epoch`` is a + number less than or equal to ``n`` representing the ``k``th cycle + + For example if ``generator_fn`` always returns ``[1,2,3]`` + then ``CyclingIterator(n=2, generator_fn)`` will iterate through + ``[1,2,3,1,2,3]`` + """ + + def __init__( + self, + n: int, + generator_fn: Callable[[int], Iterator[_T]], + start_epoch: int = 0, + ): + self._n = n + self._epoch = start_epoch + self._generator_fn = generator_fn + self._iter = generator_fn(self._epoch) + + def __iter__(self) -> Self: + return self + + def __next__(self) -> _T: + try: + return next(self._iter) + except StopIteration as eod: # eod == end of data + if self._epoch < self._n - 1: + self._epoch += 1 + self._iter = self._generator_fn(self._epoch) + return self.__next__() + else: + raise eod diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c824cc2fd018c005a59d0927a53ca449bf99102d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections.abc import Iterator, Sized +from typing import cast, TypeVar + +import torch +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler + + +T = TypeVar("T") + +__all__ = ["ElasticDistributedSampler"] + + +class ElasticDistributedSampler(DistributedSampler[T]): + """ + Sampler that restricts data loading to a subset of + the dataset for elastic training. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Args: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + start_index (optional): Which index of the dataset to start sampling from + """ + + def __init__( + self, + dataset: Dataset[T], + num_replicas: int | None = None, + rank: int | None = None, + start_index: int = 0, + ): + super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) + if not isinstance(dataset, Sized): + raise TypeError("Dataset must be an instance of collections.abc.Sized") + + # Cast to Sized for mypy + # pyrefly: ignore [redundant-cast] + sized_dataset = cast(Sized, dataset) + + if start_index >= len(sized_dataset): + raise ValueError( + f"Start index {start_index} should be less than dataset size {len(sized_dataset)}" + ) + + self.start_index = start_index + sized_dataset = cast(Sized, self.dataset) + self.num_samples = math.ceil( + float(len(sized_dataset) - self.start_index) / self.num_replicas + ) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self) -> Iterator[T]: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + sized_dataset = cast(Sized, self.dataset) + indices = ( + torch.randperm(len(sized_dataset) - self.start_index, generator=g) + .add(self.start_index) + .tolist() + ) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..7b294d222ea7de5f0b7e91ac27ef876768d47eb6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/distributed.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import datetime +import os +import socket +from contextlib import closing + +import torch.distributed as dist +from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.elastic.utils.store import barrier + + +__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] + +logger = get_logger(__name__) + +_ADDRESS_IN_USE = "Address already in use" +_SOCKET_TIMEOUT = "Socket Timeout" + +_TCP_STORE_INIT = "_tcp_store/num_members" + + +def create_c10d_store( + is_server: bool, + server_addr: str, + server_port: int = -1, + world_size: int = 1, + timeout: float = (60 * 10), # 10 min + wait_for_workers: bool = True, + retries=3, + use_libuv: bool | None = None, +): + if use_libuv is not None: + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + + if server_port == -1 and world_size > 1: + raise ValueError( + f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" + ) + + if server_port != -1: + logger.info("sever_port: %s, specified, ignoring retries", server_port) + + # only retry when server_port is NOT static + attempt = retries if server_port == -1 else 1 + while True: + if server_port != -1: + port = server_port + else: + port = get_free_port() + + logger.info( + "Creating c10d store on %s:%s\n" + " world_size : %s\n" + " is_server : %s\n" + " timeout(sec): %s\n" + " use_libuv : %s\n", + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, + ) + + try: + store = dist.TCPStore( + host_name=server_addr, + port=port, + world_size=world_size, + is_master=is_server, + timeout=datetime.timedelta(seconds=timeout), + wait_for_workers=wait_for_workers, + use_libuv=use_libuv, + ) + # skips full rank check when we don't have to wait for all workers + if wait_for_workers: + _check_full_rank(store, world_size, timeout=timeout) + logger.info("Successfully created c10d store") + return store + except RuntimeError as e: + # this is brittle, but the underlying exception type is not properly pybinded + # so we parse the error msg for now, interestingly this is how torch itself + # detects timeouts and port conflicts in their own unittests + # see - caffe2/torch/testing/_internal/common_utils.py + # TODO properly map the exceptions in pybind (c10d/init.cpp) + if str(e) == _ADDRESS_IN_USE: # this will only happen on the server + if attempt < retries: + logger.warning( + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, + ) + attempt += 1 + else: + raise RuntimeError( + f"on {server_addr}, port: {port} already in use" + ) from e + else: + raise + + +def _check_full_rank(store, world_size, timeout): + try: + barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) + except RuntimeError as e: + if str(e) == _SOCKET_TIMEOUT: + raise TimeoutError( + f"timed out waiting for all {world_size} members to join" + ) from e + else: + raise + + +def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening to socket to bind + to a port and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + .. note:: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ + sock = get_socket_with_port() + with closing(sock): + return sock.getsockname()[1] + + +def get_socket_with_port() -> socket.socket: + """ + Returns a free port on localhost that is "reserved" by binding a temporary + socket on it. Close the socket before passing the port to the entity + that requires it. Usage example + + :: + + sock = _get_socket_with_port() + with closing(sock): + port = sock.getsockname()[1] + sock.close() + # there is still a race-condition that some other process + # may grab this port before func() runs + func(port) + """ + + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() + logger.warning("Socket creation attempt failed.", exc_info=e) + raise RuntimeError("Failed to create a socket") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py new file mode 100644 index 0000000000000000000000000000000000000000..87ea0f7d64182488b40fd7fed6965ce57ec475a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/log_level.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def get_log_level() -> str: + """ + Return default log level for pytorch. + """ + return "WARNING" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..aadf37eb16b8084486a537b18f399098cbcc4fb5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/logging.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import warnings + +from torch.distributed.elastic.utils.log_level import get_log_level + + +def get_logger(name: str | None = None) -> logging.Logger: + """ + Util function to set up a simple logger that writes + into stderr. The loglevel is fetched from the LOGLEVEL + env. variable or WARNING as default. The function will use the + module name of the caller if no name is provided. + + Args: + name: Name of the logger. If no name provided, the name will + be derived from the call stack. + """ + + # Derive the name of the caller, if none provided + # Use depth=2 since this function takes up one level in the call stack + return _setup_logger(name or _derive_module_name(depth=2)) + + +def _setup_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) + return logger + + +def _derive_module_name(depth: int = 1) -> str | None: + """ + Derives the name of the caller module from the stack frames. + + Args: + depth: The position of the frame in the stack. + """ + try: + stack = inspect.stack() + assert depth < len(stack) + # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) + frame_info = stack[depth] + + module = inspect.getmodule(frame_info[0]) + if module: + module_name = module.__name__ + else: + # inspect.getmodule(frame_info[0]) does NOT work (returns None) in + # binaries built with @mode/opt + # return the filename (minus the .py extension) as modulename + filename = frame_info[1] + module_name = os.path.splitext(os.path.basename(filename))[0] + return module_name + except Exception as e: + warnings.warn( + f"Error deriving logger module name, using . Exception: {e}", + RuntimeWarning, + stacklevel=2, + ) + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py new file mode 100644 index 0000000000000000000000000000000000000000..598899e936aa0c9a1c43dda38ef2479eec03f842 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/elastic/utils/store.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable, Iterable +from contextlib import contextmanager +from datetime import timedelta + +import torch + + +DistStoreError = torch._C._DistStoreError + +_NUM_MEMBERS = "/num_members" +_LAST_MEMBER_CHECKIN = "/last_member" +_TRACE = "/TRACE" +_TRACING_GATE = "/TRACING_GATE" +_MAX_TRACE_MISSING_RANKS = 16 + + +__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + + +@contextmanager +def store_timeout(store, timeout: float): + """ + This sets the timeout and then restores the old timeout when the context + manager exits. + + Args: + store: the store to set the timeout on + timeout: the timeout to set + """ + + old_timeout = store.timeout + store.set_timeout(timedelta(seconds=timeout)) + yield + store.set_timeout(old_timeout) + + +def get_all(store, rank: int, prefix: str, world_size: int): + r""" + Given a store and a prefix, the method goes through the array of keys + of the following format: ``{prefix}{idx}``, where idx is in a range + from 0 to size, and tries to retrieve the data. + + The Rank0 process waits at the end to make sure all other processes + finished the procedure before exiting. + + Usage + + :: + + values = get_all(store, "torchelastic/data", 3) + value1 = values[0] # retrieves the data for key torchelastic/data0 + value2 = values[1] # retrieves the data for key torchelastic/data1 + value3 = values[2] # retrieves the data for key torchelastic/data2 + + """ + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) + + barrier_key = _barrier_nonblocking( + store=store, + world_size=world_size, + key_prefix=f"{prefix}/finished", + ) + if rank == 0: + # Rank0 runs the TCPStore daemon, as a result it needs to exit last. + # Otherwise, the barrier may timeout if rank0 process finished the work + # before other processes finished `get_all` method + store.wait([barrier_key]) + + return data_arr + + +def synchronize( + store, + data: bytes, + rank: int, + world_size: int, + key_prefix: str, + timeout: float = 300, +) -> list[bytes]: + """ + Synchronizes ``world_size`` agents between each other using the underlying c10d store. + The ``data`` will be available on each of the agents. + + Note: The data on the path is not deleted, as a result there can be stale data if + you use the same key_prefix twice. + + Time complexity: O(N) per worker, O(N^2) globally. + """ + with store_timeout(store, timeout): + store.set(f"{key_prefix}{rank}", data) + agent_data = get_all(store, rank, key_prefix, world_size) + return agent_data + + +def _try_detecting_missing_ranks( + store, + world_size: int, + key_prefix: str, + rank: int, + rank_decoder: Callable[[int], str], + trace_timeout: float, +) -> Iterable[str] | None: + store.set(f"{key_prefix}{rank}{_TRACE}", "") + + def _find_missing_ranks(): + missing_rank_info = set() + ranks_missing = 0 + for i in range(1, world_size): + # reduce noise, assuming in general 8 ranks per node + # It is valuable to know that 1 or >1 nodes have timed-out. + if ranks_missing >= _MAX_TRACE_MISSING_RANKS: + break + try: + if ranks_missing == 0: + store.wait( + [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) + ) + else: + # use a shortest timeout, some ranks have failed to check-in + store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) + except DistStoreError: + ranks_missing += 1 + missing_rank_info.add(rank_decoder(i)) + return missing_rank_info + + def _checkin(): + try: + store.wait([f"{key_prefix}{_TRACING_GATE}"]) + return [f"[]"] + except DistStoreError: + # in case rank0 is the source of the timeout, original exception will be raised + return None + + if rank == 0: + missing_rank_info = _find_missing_ranks() + store.set(f"{key_prefix}{_TRACING_GATE}", "") + return missing_rank_info + else: + return _checkin() + + +def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: + """ + Does all the non-blocking operations for a barrier and returns the final key + that can be waited on. + """ + num_members_key = key_prefix + _NUM_MEMBERS + last_member_key = key_prefix + _LAST_MEMBER_CHECKIN + + idx = store.add(num_members_key, 1) + if idx == world_size: + store.set(last_member_key, "") + + return last_member_key + + +def barrier( + store, + world_size: int, + key_prefix: str, + barrier_timeout: float = 300, + rank: int | None = None, + rank_tracing_decoder: Callable[[int], str] | None = None, + trace_timeout: float = 10, +) -> None: + """ + A global lock between agents. This will pause all workers until at least + ``world_size`` workers respond. + + This uses a fast incrementing index to assign waiting ranks and a success + flag set by the last worker. + + Time complexity: O(1) per worker, O(N) globally. + + Optionally, passing rank will enable tracing of missing ranks on timeouts. + `rank_tracing_decoder` lambda arg can be used to convert rank data + into a more meaningful information at an app level (e.g. hostname). + + Note: Since the data is not removed from the store, the barrier can be used + once per unique ``key_prefix``. + """ + + if rank is None: + assert rank_tracing_decoder is None, "Tracing requires rank information" + + with store_timeout(store, barrier_timeout): + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) + try: + store.wait([last_member_key]) + except DistStoreError as e: + if rank is None: + raise e + else: + missing_ranks = _try_detecting_missing_ranks( + store, + world_size, + key_prefix, + rank, + rank_tracing_decoder or (lambda x: str(x)), + trace_timeout, + ) + if missing_ranks is not None: + raise DistStoreError( + "Timed out waiting on barrier on " + "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( + rank, + key_prefix, + world_size, + f"[{', '.join(missing_ranks)}]", + barrier_timeout, + ) + ) from None + else: + raise e diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..ab338d1503ae0ac4359728ba3a5983041e678f3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/flight_recorder/fr_trace.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +"""Flight Recorder Trace Analyzer + +This script primarily merges data from individual flight recorder buffers from individual ranks in a +PyTorch Distributed program into a flattened database format that can be used for further analysis. + +However as part of the merging process, it is necessary to perform some analysis in order to match operators +on one rank with corresponding operators on other ranks and register them as one 'collective' entry. During this +process, a significant amount of useful information can already be extracted such as where the first mismatch occurs +in cases of desync (when not all ranks issue a compatible collective in a particular process group). + + +Not Yet Implemented +- TODO- tracebacks aren't implemented + +Known Issues +- Flight Recorder buffer sequence_id information is not sufficient to match collectives and coalesced collectives + unless we have the trace data from the beginning of the program. To enable confident analysis of trace buffers that + do not start from zero (and to simplify the script's matching logic) we need to add more information to the recorder. +- Currently, the script omits checking the 'status' of collectives. We can look for the first 'non completed' + collective easily enough and report that. + +Usage +python fr_trace.py [-o ] + +- Omitting the optional output file will still yield analysis information to stdout +- The output file is a pickle of the flat DB, which may change in format in the future. +- This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible. +""" + +import pickle +from collections.abc import Sequence + +from torch.distributed.flight_recorder.components.builder import build_db, transform_ft +from torch.distributed.flight_recorder.components.config_manager import JobConfig +from torch.distributed.flight_recorder.components.loader import read_dir +from torch.distributed.flight_recorder.components.types import types + + +__all__ = ["main"] + + +def main(args: Sequence[str] | None = None) -> None: + config = JobConfig() + # pyrefly: ignore [bad-assignment] + args = config.parse_args(args) + # pyrefly: ignore [missing-attribute] + assert args.trace_dir, "Trace directory trace_dir is required" + # pyrefly: ignore [bad-argument-type] + details, version = read_dir(args) + # pyrefly: ignore [missing-attribute] + if args.transform_ft: + # pyrefly: ignore [missing-attribute] + assert args.group_world_size, "World size is required for transform_ft" + # pyrefly: ignore [bad-argument-type] + details = transform_ft(details, args.group_world_size) + # pyrefly: ignore [bad-argument-type] + db = build_db(details, args, version) + # pyrefly: ignore [missing-attribute] + if args.output: + # pyrefly: ignore [no-matching-overload] + with open(args.output, "wb") as f: + pickle.dump((types, db), f) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4219250c39dc44dd0c1132e4e1b263de08f5c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__init__.py @@ -0,0 +1,69 @@ +from ._flat_param import FlatParameter as FlatParameter +from ._fully_shard import ( + CPUOffloadPolicy, + FSDPModule, + fully_shard, + MixedPrecisionPolicy, + OffloadPolicy, + register_fsdp_forward_method, + share_comm_ctx, + UnshardHandle, +) +from .fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + OptimStateKeyType, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) + + +__all__ = [ + # FSDP1 + "BackwardPrefetch", + "CPUOffload", + "FullOptimStateDictConfig", + "FullStateDictConfig", + "FullyShardedDataParallel", + "LocalOptimStateDictConfig", + "LocalStateDictConfig", + "MixedPrecision", + "OptimStateDictConfig", + "OptimStateKeyType", + "ShardedOptimStateDictConfig", + "ShardedStateDictConfig", + "ShardingStrategy", + "StateDictConfig", + "StateDictSettings", + "StateDictType", + # FSDP2 + "CPUOffloadPolicy", + "FSDPModule", + "fully_shard", + "MixedPrecisionPolicy", + "OffloadPolicy", + "register_fsdp_forward_method", + "UnshardHandle", + "share_comm_ctx", +] + +# Set namespace for exposed private names +CPUOffloadPolicy.__module__ = "torch.distributed.fsdp" +FSDPModule.__module__ = "torch.distributed.fsdp" +fully_shard.__module__ = "torch.distributed.fsdp" +MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp" +OffloadPolicy.__module__ = "torch.distributed.fsdp" +register_fsdp_forward_method.__module__ = "torch.distributed.fsdp" +UnshardHandle.__module__ = "torch.distributed.fsdp" +share_comm_ctx.__module__ = "torch.distributed.fsdp" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f2da793e0b8f1da030a8933e8375726fd98d183 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b71db11844af758d6a6a4d33515c273a2bb4915 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efa0ff5c9d16f9ca0d8363887323d7eb75550349 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a8031a0e2a049f376975f2f47cdcf6f1a473f6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e702eec7fdb597d9f099837b62ad5b360f4ab09b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fac75be32f10467690cb0d93f640defb91128ebe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15f1ee1fe0107737775c1a3eae2921483a5b78b9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3200b50970cbd93c15a45c5c76c4afa780830980 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..914c29f87687ec60ff2b349352350fbb47a8e376 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3177b5ae3ba6acbffeb82ec6e6aa1173362a4753 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5644c511d39f80572b624744fab30ab24592b7f0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c64a71157a4ca86e74a436b7cf54e5b8b8550b9b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..501dea3112bc11bf445483811dbf1bfe2fe236cc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67a23880abe768c60d929378268e690d1a3e37f8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017b651feb946fc8b535d7243adbffd8433e3629 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..464e2fbf006bdcf81bbabd9b28cf0fd3f1db2617 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93fde939a3c780ab7a8d78aea2605300b96bd3cd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8ef120c13e4cf51ded4e4ef87d2ffb5f2dd01f8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54d6c974caedf83a473148b7eb85da267f2be070 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_common_utils.py @@ -0,0 +1,550 @@ +# mypy: allow-untyped-defs +""" +This file includes private common utilities for FSDP. +""" + +import logging +import traceback +import warnings +import weakref +from collections.abc import Callable, Generator, Iterable +from enum import auto, Enum +from functools import partial +from itertools import chain +from typing import Any, cast, no_type_check, Optional, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +import torch.nn as nn +from torch.distributed._composable_state import _get_module_state, _State +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.utils import _apply_to_tensors +from torch.utils._mode_utils import no_dispatch + +from .api import ( + FullOptimStateDictConfig, + FullStateDictConfig, + OptimStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictType, +) + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + + from ._flat_param import FlatParamHandle + +FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" +FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." +FSDP_FLATTENED = "_fsdp_flattened" + +# Save a global mapping from module to its input tensor dtype to be populated +# during the forward pre-hook and consumed in the forward post-hook when +# overriding a module's mixed precision +# NOTE: We currently take the last input tensor's dtype in the case of multiple +# floating-point input tensors, which may be incorrect. However, since there is +# not a 1:1 correspondence between input and output tensors, we must use *some* +# heuristic like this to predict the desired output dtype. +_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + + +class _FSDPDeviceHandle: + """ + This is a simple abstraction for FSDP computing devices, + which enables custom backends that implement CUDA-like + semantics to be integrated with FSDP. + """ + + def __init__(self, device: torch.device, backend: Any = None): + if backend is None: + try: + self.__backend = getattr(torch, device.type) + # pyrefly: ignore [read-only] + self.__device = device + except AttributeError as exc: + raise AttributeError( + f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'." + ) from exc + else: + self.__backend = backend + + @classmethod + def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": + """ + Return a device handle corresponding to the device, and through this handle, + operations with the same semantics as CUDA can be performed on the device. + Just return torch.cuda if the device is cuda to make attribute-access faster. + Custom backend must first register a module with the same name with {device.type} on torch. + """ + if device.type == "cuda": + return cast(_FSDPDeviceHandle, torch.cuda) + elif device.type == "mtia": + return cast(_FSDPDeviceHandle, torch.mtia) + return cls(device) + + def __getattr__(self, name: str, /) -> Any: + try: + return getattr(self.__backend, name) + except AttributeError as exc: + raise AttributeError( + f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{name}'" + ) from exc + + +class _UninitializedDeviceHandle(_FSDPDeviceHandle): + def __init__(self) -> None: + pass + + def __getattribute__(self, name: str, /) -> Any: + raise RuntimeError("Trying to use an uninitialized device handle.") + + +class _FSDPState(_State): + def __init__(self) -> None: + # TODO: Move all the attributes to this class to enable typing for + # FSDP/fully_shard. + self._ignored_modules: set[nn.Module] = set() + self._ignored_params: set[nn.Parameter] = set() + # Buffer names are cleaned (without wrapper prefixes) + self._ignored_buffer_names: set[str] = set() + self.process_group: Optional[dist.ProcessGroup] = None + self.rank: int = -1 + self.world_size: int = -1 + self._device_mesh: Optional[DeviceMesh] = None + self.sharding_strategy = ShardingStrategy.FULL_SHARD + self._use_orig_params: bool = False + self.training_state = TrainingState.IDLE + self._unshard_params_ctx: dict[nn.Module, Generator] = {} + self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT + self._state_dict_config: StateDictConfig = FullStateDictConfig() + self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig() + self._is_root: Optional[bool] = None + self._handle: Optional[flat_param_file.FlatParamHandle] = None + self._fully_sharded_module_to_handle: dict[ + nn.Module, Optional[flat_param_file.FlatParamHandle] + ] = {} + self.compute_device: Optional[torch.device] = None + self._gradient_predivide_factor: int = 0 + self._gradient_postdivide_factor: int = 0 + self._comm_hook: Optional[Callable] = None + self._comm_hook_state: Optional[Any] = None + self._unshard_event: Optional[torch.Event] = None + # Abstract device handle for fsdp compute device. For now, + # the compute device must implement cuda semantics used by fsdp + self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() + # All following attributes should only be used for root states: + # Save these static lists to avoid the repeated tree traversals + self._all_fsdp_states: list[_FSDPState] = [] + self._all_handles: list[flat_param_file.FlatParamHandle] = [] + self._fsdp_extension: Optional[FSDPExtensions] = None + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: + state = _get_module_state(module) + if state is None or not isinstance(state, _FSDPState): + return None + return state + + +def _get_module_fsdp_state_if_fully_sharded_module( + module: nn.Module, +) -> Optional[_FSDPState]: + state = _get_module_fsdp_state(module) + if state is None: + return None + if state == module: # FullyShardedDataParallel module case. + return state + if module in state._fully_sharded_module_to_handle: # fully_shard case. + return state + return None + + +class TrainingState(Enum): + """ + An enum that indicates the state of a ``FullyShardedDataParallel` instance. + """ + + IDLE = auto() + FORWARD_BACKWARD = auto() + SUMMON_FULL_PARAMS = auto() + + +class HandleTrainingState(Enum): + """ + An enum that indicates the state of a ``FlatParamHandle`. + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + SUMMON_FULL_PARAMS = auto() + + +def _is_composable(state: _FSDPState): + # TODO: This is a temporary hack for differentiate between code paths. + return not isinstance(state, nn.Module) + + +@no_type_check +def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]: + """ + Returns the ``FlatParamHandle`` s corresponding to ``module``. This is + the handle that contains some parameter in ``module``. + """ + if _is_composable(state): + # A valid FSDP state may have no managed parameters and hence no + # handles, meaning no entry in `_fully_sharded_module_to_handles` + if state._handle is None: + return None + if module not in state._fully_sharded_module_to_handle: + raise AssertionError( + f"Expects a fully sharded module but got {module} on rank {state.rank}" + ) + return state._fully_sharded_module_to_handle[module] + else: + # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. + return module._handle + + +@no_type_check +def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: + """Returns if ``module`` has parameters managed by FSDP.""" + return _module_handle(state, module) is not None + + +def _get_sharding_strategy(handle): + """ + Returns the sharding strategy of the handle. + """ + return handle._sharding_strategy if handle else None + + +def clean_tensor_name(tensor_name: str) -> str: + """ + Cleans the parameter or buffer name by removing any module wrapper + prefixes. + """ + tensor_name = tensor_name.replace(FSDP_PREFIX, "") + # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as + # it couples `CheckpointWrapper` and FSDP and also does not scale for more + # module wrappers. + tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "") + return tensor_name + + +def _set_fsdp_flattened(tensor: torch.Tensor) -> None: + """ + Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to + avoid re-flattening it during nested construction. + """ + setattr(tensor, FSDP_FLATTENED, True) + + +def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: + """Returns if ``tensor`` has been marked as flattened by FSDP.""" + return getattr(tensor, FSDP_FLATTENED, False) + + +def _named_parameters_with_duplicates( + module: nn.Module, **kwargs: Any +) -> list[tuple[str, nn.Parameter]]: + """ + This API is required as some modules overwrite `named_parameters()` but do not support + `remove_duplicate`. + """ + if "remove_duplicate" in kwargs: + raise AssertionError( + "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." + ) + kwargs["remove_duplicate"] = False + try: + ret = list(module.named_parameters(**kwargs)) + except AssertionError: + kwargs.pop("remove_duplicate") + ret = list(module.named_parameters(**kwargs)) + return ret + + +def _get_param_to_fqns( + model: torch.nn.Module, + dedup_shared_params: bool = True, +) -> dict[nn.Parameter, list[str]]: + """ + Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here, + we use canonical to mean the fully-qualified name assigned to the parameter + based on its position in the original nn.Module hierarchy before any wrapper + or parallelism has been applied to it. This is in contrast to FQNs that may be + generated after parallelisms or wrappers have been applied to the model. + + Each normal parameter maps to a singleton list containing its FQN, while each + ``FlatParameter`` maps to a list of its original parameter FQNs, which may + have length greater than one. All FQNs are prefixed starting from ``model``. + + In the case where FSDP was applied with ``use_orig_params=True``, there should be no + ``FlatParameter`` s registered to the model's modules and this mapping will only + contain mappings from ``nn.Parameter`` s to singleton FQN lists. + + It is only in the case where FSDP was applied with ``use_orig_params=False`` where + a ``FlatParameter`` will be registered in place of the original parameters and there + will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the + original parameters. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance). + dedup_shared_params (bool): For shared parameters, if ``True``, only + includes the FQNs corresponding to the first encounter of the + shared parameter in the module traversal; if ``False``, then + includes the FQNs across all encounters. (Default: ``True``) + """ + + def module_fn(module, prefix, tree_level, param_to_fqns): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + local_fqns = ( + param._fqns + if isinstance(param, flat_param_file.FlatParameter) + else [param_name] + ) # prefixed from `module` + global_fqns = [ + clean_tensor_name(prefix + name) for name in local_fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + is_shared_param = param in param_to_fqns + if not is_shared_param: + param_to_fqns[param] = global_fqns + else: + if isinstance(param, flat_param_file.FlatParameter): + # DMP overwrites `named_parameters` and skip (advance to + # the next child module) the wrapped_module (e.g., + # _dmp_wrapped_module and _fsdp_wrapped_module). When a user + # calls `named_child` to traverse the module recursively and + # calls `named_parameters` with `recurse=False`, parameters + # will be traversed more than once. + # This hack is specified designed for DMP + FSDP. We + # overwrite the flat_parameters traversal result to only obtain + # the last one, which happens to be the correct one. + # + # TODO: Remove this hack once DMP + FSDP is not supported. + warnings.warn( + "FlatParameter is being traversed more than once. " + "This case should only happen when using " + "DistributedModelParallel with FullyShardedDataParallel.", + stacklevel=2, + ) + param_to_fqns[param] = global_fqns + elif not dedup_shared_params: + param_to_fqns[param].extend(global_fqns) + + def return_fn(param_to_fqns): + return param_to_fqns + + param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in _named_parameters_with_duplicates(model)], + param_to_unflat_param_names, + ) + + +@no_type_check +def _log_post_backward_hook( + state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger +) -> None: + # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for. + # Below logging of module names this post-bwd hook fires for can help debug certain + # cases where hooks don't fire, such as under certain activation checkpoint configs. + if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO: + param_fqns = _get_handle_fqns_from_root(state, handle) + logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns) + + +@no_type_check +def _get_handle_fqns_from_root( + state: _FSDPState, handle: "FlatParamHandle" +) -> Optional[list[str]]: + if handle is None: + return None + param_to_fqn = state._exec_order_data.param_to_fqn + handle_params = handle.flat_param._params # only populated for use_orig_params + param_fqns = [*chain.from_iterable(param_to_fqn[p] for p in handle_params)] + return param_fqns + + +def _apply_to_modules( + root_module: torch.nn.Module, + module_fn: Callable, + return_fn: Callable, + filter_fqns: Optional[list[str]] = None, + *args, + **kwargs, +): + """ + Performs a pre-order traversal of the modules in the hierarchy rooted at + ``root_module``, applying ``module_fn`` at each module and finally + returning a value using ``return_fn``. The traversal constructs the full + module prefix name (e.g. "module.submodule." just like in model state dict) + and makes that available to ``module_fn``. + + ``filter_fqns`` is used because some module may have its own prefix similar + to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten + to remove the prefix. + """ + + def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): + # Call the module function before recursing over children (pre-order) + module_fn(module, prefix, tree_level, *args, **kwargs) + for submodule_name, submodule in module.named_children(): + if submodule is None: + continue + new_prefix = prefix + submodule_name + "." + new_tree_level = tree_level + 1 + if filter_fqns is not None: + for fqn in filter_fqns: + if fqn.startswith(new_prefix): + break + else: + # DMP's named_parameter() will mess up the traversal with + # ``named_children`` + `named_parameter(recurse=False)``. + # This hack is a must to make the traversal work. + # TODO: Remove this hack once DMP + FSDP is not supported. + # It turns out that recursive wrapping may trigger this as + # well. + if ( + submodule_name == "_fsdp_wrapped_module" + or submodule_name == "_dmp_wrapped_module" + ): + new_prefix = prefix + elif submodule_name == "module": + new_prefix = prefix + f(submodule, new_prefix, new_tree_level, *args, **kwargs) + + f(root_module, "", 0, *args, **kwargs) + return return_fn(*args, **kwargs) + + +@no_type_check +def _assert_in_training_states( + state: _FSDPState, + training_states: list[TrainingState], +) -> None: + """Asserts that FSDP is in the states ``_training_states``.""" + # Raise a `ValueError` instead of using `assert` to ensure that these + # logical assertions run even if `assert`s are disabled + if state.training_state not in training_states: + msg = ( + f"expected to be in states {training_states} but current state is " + f"{state.training_state}" + ) + # Print the error on rank 0 in case this is called in the backward pass + if state.rank == 0: + if isinstance(state, nn.Module): + print(f"Asserting FSDP instance is: {state}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + +def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]: + """ + Returns: + Set[nn.Module]: The subset of ``modules`` that are root modules (i.e. + parent-less) with respect to the modules in the set itself. In other + words, these are the modules in ``modules`` that are not the child of + any other module in ``modules``. + """ + root_modules: set[nn.Module] = set() + module_to_submodules = {module: set(module.modules()) for module in modules} + for candidate_module in modules: + is_root_module = True + for module, submodules in module_to_submodules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in submodules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.add(candidate_module) + return root_modules + + +def _override_module_mixed_precision( + root: torch.nn.Module, + module_classes_to_override: Iterable[type[nn.Module]], + wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006 +) -> set[type[nn.Module]]: + module_classes_to_override = tuple(set(module_classes_to_override)) + # Return a set of the actually overridden module classes + overridden_module_classes: set[type[nn.Module]] = set() + for mod in root.modules(): + if isinstance(mod, module_classes_to_override): + overridden_module_classes.add(type(mod)) + mod._wrap_overrides = wrap_override_dict # type: ignore[assignment] + # TODO: We need to run this mixed precision ignored module in fp32, + # but ensure subsequent modules, that may possibly be running with + # mixed precision, still receive the appropriate precision inputs + # without user having to adjust mixed precision config too much. + # As a result, we attach pre and post forward hooks to up / down + # cast. We should revisit this design. + + def cast_fn( + dtype: torch.dtype, module: nn.Module, x: torch.Tensor + ) -> torch.Tensor: + if not torch.is_floating_point(x) or x.dtype == dtype: + return x + _MODULE_TO_INP_DTYPE[module] = x.dtype + return x.to(dtype) + + def forward_pre_hook(module, args): + return _apply_to_tensors(partial(cast_fn, torch.float32, module), args) + + def forward_post_hook(module, args, output): + # NOTE: If the forward did not have any floating-point tensors, + # then the dtype will not be set for this module, and we do not + # upcast the dtype. + if module in _MODULE_TO_INP_DTYPE: + old_dtype = _MODULE_TO_INP_DTYPE[module] + return _apply_to_tensors( + partial(cast_fn, old_dtype, module), output + ) + + # We intentionally append both of these hooks so that they run after + # all other hooks. + mod.register_forward_pre_hook(forward_pre_hook, prepend=False) + mod.register_forward_hook(forward_post_hook, prepend=False) + return overridden_module_classes + + +def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: + # FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors + if tensor.device.type not in [ + "cuda", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + return + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + # from @ezyang: + # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin + # Looking over the PR, it looks like this is because we don't actually support Stream arguments + # in torch dispatch, so it just chokes. + # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False), + # a better version of this would just be to check if there are any modes before disabling dispatch. + # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here. + tensor.record_stream(stream) + else: + with no_dispatch(): + tensor.record_stream(stream) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5a411f8c556ff1922775514cb2361a87bb492d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_debug_utils.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +import logging +import time +from collections import defaultdict +from collections.abc import Iterator +from contextlib import contextmanager +from enum import Enum + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _get_module_fsdp_state, + clean_tensor_name, +) + + +logger = logging.getLogger(__name__) + + +class SimpleProfiler: + class Type(str, Enum): + ALL = "all" + ALLGATHER = "all_gather" + ALLGATHER_OBJ = "all_gather_object" + RESHARDING = "resharding" + H2D = "H2D" + D2H = "D2H" + + results: dict[str, float] = defaultdict(float) + profiling: set[str] = set() + + @classmethod + def reset(cls) -> None: + cls.results.clear() + cls.profiling.clear() + + @classmethod + @contextmanager + def profile(cls, profile_type: str) -> Iterator[None]: + if profile_type in cls.profiling: + raise AssertionError( + f"{profile_type} is already being profiled. " + "SimpleProfiler does not support profiling multiple instances at " + "the same time. " + ) + + cls.profiling.add(profile_type) + begin = time.monotonic() + try: + yield + finally: + end = time.monotonic() + cls.results[profile_type] += end - begin + cls.profiling.remove(profile_type) + + @classmethod + def dump_and_reset(cls, msg: str) -> None: + # This cannot be combined with DETAIL distributed log + # as the profiling will be very incorrect. + if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO: + logger.info("%s %s", msg, cls.results) + cls.reset() + + +def _get_sharded_module_tree_with_module_name_to_fqns( + model: torch.nn.Module, +) -> tuple[str, dict[str, list[str]]]: + """ + It is used for composable fully_shard() code path, it returns + 1. sharded module tree info: each line represents a submodule name that contains the + submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`, + the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree + level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model + is like this: + [CompositeModel] FULLY SHARDED + l1[Linear] + u1[UnitModule] FULLY SHARDED + u1.l1[Linear] + u1.seq[Sequential] + u1.seq.0[ReLU] + u1.seq.1[Linear] + u1.seq.2[ReLU] + u1.l2[Linear] + u2[UnitModule] FULLY SHARDED + u2.l1[Linear] + u2.seq[Sequential] + u2.seq.0[ReLU] + u2.seq.1[Linear] + u2.seq.2[ReLU] + u2.l2[Linear] + l2[Linear] + 2. a dict mapping from the concated module FQN and class name to a list of its managed + original parameters' FQNs. An example of the dict for the above toy sharded model is like this: + {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'], + 'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'], + 'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias'] + } + All FQNs are prefixed starting from ``model``. + + Args: + model (torch.nn.Module): Root module (which may or may not be passed to + composable `fully_shard()`). + """ + + def module_fn( + module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns + ): + num_spaces = tree_level * 4 + trimed_prefix = ( + prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix + ) + prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]" + printed_prefixed_module_name = " " * num_spaces + prefixed_module_name + + state = _get_module_fsdp_state(module) + if state is None: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + return + + handle = state._fully_sharded_module_to_handle.get(module, None) + + if handle: + sharded_tree_info[0] += ( + printed_prefixed_module_name + " FULLY SHARDED" + "\n" + ) + else: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + + if handle: + param = handle.flat_param + if not isinstance(param, flat_param_file.FlatParameter): + raise AssertionError(f"Expected FlatParameter, got {type(param)}") + global_fqns = [ + clean_tensor_name(prefix + name) for name in param._fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + + if prefixed_module_name in sharded_module_name_to_fqns: + sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns) + else: + sharded_module_name_to_fqns[prefixed_module_name] = global_fqns + + def return_fn(sharded_tree_info, sharded_module_name_to_fqns): + return sharded_tree_info[0], sharded_module_name_to_fqns + + # Use List to mutate its value in place while running the recursive functions + sharded_tree_info: list[str] = [ + "", + ] + sharded_module_name_to_fqns: dict[str, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in model.named_parameters()], + sharded_tree_info, + sharded_module_name_to_fqns, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77bcd43b63be27da8e8b79f877ce7cb9d67c74b8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_dynamo_utils.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +def _annotate_modules_for_dynamo( + module: nn.Module, + ignored_modules: set[nn.Module], + use_orig_params: bool, +) -> None: + """ + Annotates the submodules in ``module`` 's tree, except those in + ``ignored_modules``, indicating that the submodules are FSDP-managed and + saving the ``use_orig_params`` setting passed to the FSDP constructor. + """ + for submodule in module.modules(): + if submodule not in ignored_modules: + """[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + + Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since + it skips tracing all the torch.distributed.fsdp code. + - Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also + gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops. + - However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*), + and we need a way to indicate to dynamo which modules are wrapped by FSDP. + + (*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough + guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming + their code is well-behaved. + + One particular issue with specialized NNModules for FSDP is that the + views created for orig_params are captured into the compiled graph on the first iteration, and while + they are always going to point to the correct flatparameter and give correct results, their order + of creation influences the order of backward execution, preventing overlap of comm and computation + during backward. We need to _use_ the new parameter views created on each forward iteration, in + order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve + this by capturing the module code more 'functionally' and passing parameters in as inputs each time. + """ + submodule._is_fsdp_managed_module = True # type: ignore[assignment] + + # Dynamo only supports FSDP with use_orig_params=True. + # This is hacky, but I could not think of another way to add an assertion to dynamo + # for this, since Dynamo skips all the FSDP code frames and thus can't inspect the + # FSDP module directly + submodule._fsdp_use_orig_params = use_orig_params # type: ignore[assignment] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db2ea7bfae0b92a6a103ac35655a6da627761e7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_exec_order_utils.py @@ -0,0 +1,366 @@ +# mypy: allow-untyped-defs +import itertools +import warnings +from enum import auto, Enum +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns +from torch.distributed.fsdp._flat_param import FlatParamHandle + + +class _ExecOrderWarnStatus(Enum): + """Used internally for execution order validation.""" + + NONE = auto() # no deviation yet + WARNING = auto() # deviated this iteration; currently issuing warnings + WARNED = auto() # deviated in a previous iteration + + +class _ExecOrderData: + """ + This contains the data structures to track the execution order. We track + the pre-forward order on the *first* iteration for forward prefetching + (which thus assumes static graph) and the post-forward order on *every* + iteration for backward prefetching (which thus does not assume static + graph but may be provide an incorrect order). + """ + + def __init__( + self, + debug_level: dist.DebugLevel, + backward_prefetch_limit: int, + forward_prefetch_limit: int, + ) -> None: + # Tracks the (static) pre-forward order for execution order validation + # and forward prefetching + self.handles_pre_forward_order: list[FlatParamHandle] = [] + # Tracks the post-forward order for pre-backward prefetching + self.handles_post_forward_order: list[Optional[FlatParamHandle]] = [] + self._iter = 0 + + # Gives the max number of backward/forward prefetched all-gathers by a + # single module + self._backward_prefetch_limit = backward_prefetch_limit + self._forward_prefetch_limit = forward_prefetch_limit + + # Data structures for execution order validation + self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL + self.process_group: Optional[dist.ProcessGroup] = None + self.world_size: Optional[int] = None + self.all_handles: list[FlatParamHandle] = [] + # Names are prefixed from the root module + self.param_to_fqn: dict[nn.Parameter, list[str]] = {} + # Current index in the pre-forward execution order + self.current_order_index = 0 + self.warn_status = _ExecOrderWarnStatus.NONE + + def init( + self, + state: _FSDPState, + root_module: nn.Module, + process_group: dist.ProcessGroup, + ) -> None: + """ + Initializes the data structures needed for checking the forward order. + This should be called after a root FSDP instance has been set during + lazy initialization. + """ + self.process_group = process_group + self.rank = process_group.rank() + self.world_size = process_group.size() + # Fix an order over the handles, which should be the same across ranks + for handle in traversal_utils._get_fsdp_handles(root_module): + index = len(self.all_handles) + self.all_handles.append(handle) + handle._handle_index = index + self.param_to_fqn = _get_param_to_fqns(root_module) + # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles` + # to check that all ranks have the same handles in the same order. + # https://github.com/pytorch/pytorch/issues/79620 + + @property + def is_first_iter(self) -> bool: + return self._iter == 0 + + def get_handle_to_backward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to backward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._post_forward_index + if current_index is None: + return None + target_index = current_index - 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._backward_prefetch_limit): + if target_index < 0: + break + target_handle = self.handles_post_forward_order[target_index] + target_index -= 1 + return target_handle + + def get_handle_to_forward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to forward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._pre_forward_order_index + if current_index is None: + return None + target_index = current_index + 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._forward_prefetch_limit): + if target_index >= len(self.handles_pre_forward_order): + break + target_handle = self.handles_pre_forward_order[target_index] + target_index += 1 + return target_handle + + def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None: + """ + Records ``handles`` in the post-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + Unlike :meth:`record_pre_forward`, this records the order *every* + iteration with the expectation that the recorded order is reset in + :meth:`next_iter`. + """ + if not handle: + return + # Only record the first usage of a handles key + if handle._post_forward_index: + self.handles_post_forward_order.append(handle) + return + index = len(self.handles_post_forward_order) + handle._post_forward_index = index + self.handles_post_forward_order.append(handle) + + def record_pre_forward( + self, handle: Optional[FlatParamHandle], is_training: bool + ) -> None: + """ + Records ``handles`` in the pre-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + On the first iteration, this checks the execution order across ranks. + See :meth:`_check_order` for details. + """ + if not handle: + return + self._check_order(handle, is_training) + # Fix the order after the first iteration and only record the first + # usage of a handles key + if not self.is_first_iter or handle._pre_forward_order_index is not None: + return + index = len(self.handles_pre_forward_order) + handle._pre_forward_order_index = index + self.handles_pre_forward_order.append(handle) + + def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None: + """ + Checks the forward execution order as long as ``is_training`` is + ``True`` since checking in eval mode is not supported. This only checks + if the distributed debug level is DETAIL. + + - On the first iteration, this uses all-gathers to check that all ranks + are all-gathering the same handles and hence ``FlatParameter`` s, + raising an error if not. + - On subsequent iterations, this checks that each rank is locally + consistent with its own forward order from the first iteration, issuing + a warning if not. This issues a warning on the first deviating + iteration and stops warning thereafter. + """ + # Do not check order in eval mode since the post-backward callback does + # not run so it cannot be used to mark the end of an iteration + if not is_training or not self._checking_order: + return + if self.is_first_iter: + msg_prefix = "Forward order differs across ranks:" + optional_local_indices: tuple[Optional[int], ...] = ( + self._get_handle_indices(handle) + ) + device = handle.device # guaranteed to be non-CPU + num_valid_indices = sum( + (index is not None) for index in optional_local_indices + ) + tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = { + "dtype": torch.int32, + "device": device, + } + world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs) # type: ignore[arg-type, call-overload] + local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload] + dist.all_gather_into_tensor( + world_num_valid_indices, + local_num_valid_indices, + group=self.process_group, + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_num_valid_indices = world_num_valid_indices.cpu() + # Check that all ranks plan to all-gather the same number of + # parameters + # TODO (awgu): Since every module has at most one handle in the + # current implementation, this should never raise the error. + if self.world_size is None: + raise AssertionError("Expected world_size to not be None") + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, n1), (r2, n2) in itertools.combinations( + ( + (rank, world_num_valid_indices[rank]) + for rank in range(self.world_size) + ), + 2, + ): + if n1 != n2: + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering {n1} parameters " + f"while rank {r2} is all-gathering {n2} parameters" + ) + world_indices = torch.zeros( # type: ignore[call-overload] + self.world_size * num_valid_indices, **tensor_kwargs + ) + local_indices = torch.tensor(optional_local_indices, **tensor_kwargs) # type: ignore[arg-type] + dist.all_gather_into_tensor( + world_indices, local_indices, group=self.process_group + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_indices = world_indices.cpu() + # Check that all ranks plan to all-gather the same index parameters + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, i1), (r2, i2) in itertools.combinations( + ( + ( + rank, + world_indices[ + rank * num_valid_indices : (rank + 1) + * num_valid_indices + ], + ) + for rank in range(self.world_size) + ), + 2, + ): + if i1 != i2: + r1_param_names = self._get_names_from_handle_indices(i1) + r2_param_names = self._get_names_from_handle_indices(i2) + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering parameters " + f"for {r1_param_names} while rank {r2} is all-gathering " + f"parameters for {r2_param_names}" + ) + else: + # Only issue warnings on the first deviating iteration and stop + # checking thereafter to avoid flooding the console + if self.warn_status == _ExecOrderWarnStatus.WARNED: + return + msg_prefix = None # non-`None` means we should warn + if self.current_order_index >= len(self.handles_pre_forward_order): + # This iteration sees extra all-gather(s) compared to the first + msg_prefix = ( + "Expected to not all-gather any more parameters in the " + "forward but trying to all-gather parameters for " + ) + else: + expected_handle = self.handles_pre_forward_order[ + self.current_order_index + ] + if expected_handle != handle: + expected_param_names = self._get_names_from_handles(expected_handle) + msg_prefix = ( + f"Expected to all-gather for {expected_param_names} " + "but trying to all-gather parameters for " + ) + if msg_prefix is not None: + param_names = self._get_names_from_handles(handle) + msg_suffix = ( + f"{param_names}" + if param_names + else "a newly-added parameter since construction time" + ) + warnings.warn( + "Forward order differs from that of the first iteration " + f"on rank {self.rank}. Collectives are unchecked and may " + f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}", + stacklevel=2, + ) + self.warn_status = _ExecOrderWarnStatus.WARNING + self.current_order_index += 1 + + def _get_handle_indices( + self, + handle: FlatParamHandle, + ) -> tuple[Optional[int], ...]: + """ + Returns the handle indices (i.e. indices into ``self.all_handles``) + corresponding to the handles in ``handle``. An entry in the + returned tuple is ``None`` if the handle is invalid. + """ + indices: list[Optional[int]] = [] + if handle: + indices.append(handle._handle_index) + return tuple(indices) + + def _get_names_from_handle_indices( + self, + handle_indices: tuple[int, ...], + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handle_indices``. If a + handle index is invalid, then its FQNs are omitted from the returned + list. + """ + fqns: list[list[str]] = [] + for index in handle_indices: + if index is None or index < 0 or index >= len(self.all_handles): + continue + handle = self.all_handles[index] + flat_param = handle.flat_param + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def _get_names_from_handles( + self, + handle: FlatParamHandle, + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handles_key``. If a handle + is invalid, then its FQNs are omitted from the returned list. + """ + fqns: list[list[str]] = [] + if handle: + flat_param = handle.flat_param + if flat_param in self.param_to_fqn: + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def next_iter(self): + """ + Advances the internal data structures per iteration. This should be + called in the post-backward callback since that marks the true end of + an iteration. + """ + self._iter += 1 + self.handles_post_forward_order.clear() + if self._checking_order: + self.current_order_index = 0 + if self.warn_status == _ExecOrderWarnStatus.WARNING: + self.warn_status = _ExecOrderWarnStatus.WARNED diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py new file mode 100644 index 0000000000000000000000000000000000000000..85e4c23d509f8c8751ac60572a7a4a78da0fc9cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py @@ -0,0 +1,2841 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import logging +import os +import warnings +from collections.abc import Callable, Generator, Iterator, Sequence +from enum import auto, Enum +from itertools import accumulate, chain +from typing import Any, cast, NamedTuple, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _named_parameters_with_duplicates, + _no_dispatch_record_stream, + _set_fsdp_flattened, + HandleTrainingState, +) +from torch.distributed.utils import ( + _alloc_storage, + _data_ptr_allocated, + _free_storage, + _p_assert, +) +from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined] +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + +from ._fsdp_extensions import ( + _ext_post_unflatten_transform, + _ext_pre_flatten_transform, + FSDPExtensions, +) + + +__all__ = [ + "FlatParameter", + "FlatParamHandle", + "FlatParamShardMetadata", + "ParamInfo", + "SharedParamInfo", + "HandleShardingStrategy", +] + +logger = logging.getLogger(__name__) + + +""" +[Note: Fully Sharded Module] +We define the "fully sharded module" to be the original ``nn.Module`` that owns +a ``FlatParamHandle``. It is the *single* module logically responsible for the +*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given +forward or backward pass. The fully sharded module should be passed to the +``FlatParamHandle`` constructor. + +For the wrapper code path: +- The ``FullyShardedDataParallel`` module wrapping the fully sharded module +runs the unshard/reshard on behalf of the fully sharded module by overriding +``nn.Module.forward``. +- The fully sharded module is exactly the module passed to the +``FullyShardedDataParallel`` constructor's ``module`` argument. + +For the non-wrapper code path: +- Hooks registered on the fully sharded module run the unshard/reshard. +- The fully sharded module may either be the direct argument to ``fully_shard`` +or a submodule chosen by the provided wrapping policy. +""" + +# Environment variable toggling whether to use unsafe `setattr()` for view +# setting in `_use_sharded_views()` and `_use_unsharded_views()` +# We should use 'safe' by default since it respects method overrides, but for +# special cases such as for high CPU overhead or for intentionally bypassing +# checks in the overrides, we may use 'unsafe'. +_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR" + +# Environment variable toggling whether to check for parameter/gradient +# writeback in case their storages change after FSDP initialization +# We should check by default since it prevents silent correctness errors, but +# since such changes are atypical, we may want to skip the check to save CPU +# overhead, especially since the check happens in the pre-forward and +# pre-backward each iteration. +_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK" + +# Env var toggling whether when model is in .eval() mode, should we run in fp32 +# or the reduced precision. +_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL" + +# Some value to set padding in tensors to for debuggability +_FLAT_PARAM_PADDING_VALUE = 42 + +# Environment variables for disabling the all-gather and reduce-scatter +# communication ops for ablation studies. Note that without these communication +# ops the training won't converge, and you probably need to disable correctness +# checks in your model. +_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER" +_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE" + + +# TODO: Define this for now to avoid circular imports. See if we can remove. +class HandleShardingStrategy(Enum): + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.FULL_SHARD, + HandleShardingStrategy.HYBRID_SHARD, +) +NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.SHARD_GRAD_OP, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +class ParamInfo(NamedTuple): + """Information for an original parameter.""" + + param_name: str # unprefixed + module: nn.Module + module_name: str + + +class SharedParamInfo(NamedTuple): + """ + Additional information for a shared parameter. + + For each shared parameter, we designate one module and its parameter + variable to be the primary owner, determined as the first one encountered + in the parameter walk. These are prefixed with "prim". The primary module + and parameter do not have their own :class:`SharedParamInfo` instance. + """ + + param_name: str # unprefixed + module: nn.Module + module_name: str + prim_param_name: str # unprefixed + prim_module: nn.Module + prim_module_name: str + + +class _ShardParamInfo(NamedTuple): + """Shard-related information for an original parameter.""" + + in_shard: bool + # Use to index into the sharded flat parameter, e.g. + # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]` + offset_in_shard: Optional[int] + numel_in_shard: Optional[int] + # Use to get part of the parameter in the local shard from a flattened + # version of the unsharded parameter, e.g. either + # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` or + # `param.as_strided((param.numel(),), (1,))[intra_param_start_idx : intra_param_end_idx + 1]` + intra_param_start_idx: Optional[int] + intra_param_end_idx: Optional[int] # inclusive + + +class FlatParamShardMetadata(NamedTuple): + """ + This holds metadata specific to this rank's shard of the flat parameter. + + Attributes: + param_names (Tuple[str, ...]): Prefixed parameter names of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_strides (Tuple[torch.Size, ...]): Parameter strides of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_contiguities (Tuple[bool, ...]): Parameter `.contiguous` call results + of this rank's shard of the parameters; see :class:`FlatParameter`. + param_numels (Tuple[int, ...]): Parameter numels of this rank's shard + of the parameters; see :class:`FlatParameter`. + param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in + units of numels) giving this rank's part of each flattened + original parameter. + """ + + param_names: tuple[str, ...] + param_shapes: tuple[torch.Size, ...] + param_strides: tuple[tuple[int, ...], ...] + param_contiguities: tuple[bool, ...] + param_numels: tuple[int, ...] + param_offsets: tuple[tuple[int, int], ...] + + +class _FlatParameterMeta(_ParameterMeta): + # Make `isinstance(t, FlatParameter)` return True for custom tensor + # instances that have the _is_flat_param flag for BC + def __instancecheck__(self, instance): + # NB: do NOT test the super implementation + return isinstance(instance, torch.Tensor) and getattr( + instance, "_is_flat_param", False + ) + + +class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): + """ + This is the flat parameter used by :class:`FullyShardedDataParallel`. + + It is comprised of one or more original parameters, which are flattened and + concatenated to construct the flat parameter. + + Under the current design, this parameter logically represents both the + unsharded and sharded flat parameter, and its data changes storages + dynamically. + - In the :class:`FullyShardedDataParallel` constructor, the parameter + is initialized as unsharded and then sharded in-place. + - At runtime, the parameter is lazily (re)-initialized. The sharded + parameter data is saved in ``self._local_shard``, and a new ``Tensor`` + ``self._full_param_padded`` is created, which is the all-gather + destination and owns the unsharded parameter storage thereafter. (See + :meth:`FlatParamHandle.init_flat_param_attributes`.) + - Throughout runtime, the parameter data changes storages as needed, + e.g. to the sharded flat parameter, low precision sharded flat + parameter, or the unsharded flat parameter. + + NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter`` + padding, we have two versions of the per-parameter numels, one that + includes the padding (``_numels_with_padding``) and one that does not + (``_numels``). The former may have length longer than the other data + structures, while the latter has the same length as the number of actual + original parameters like the other per-parameter data structures. + + NOTE: This is not a real class; instead, you will always get a Parameter + back out if you try to create one of these. This is similar to the trick + we implemented for Parameter to get it to work with subclasses; this + is primarily so that FlatParameter supports combination with FakeTensor. + + Attributes: + _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size + without right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. + _padded_unsharded_size (torch.Size): Unsharded flat parameter's size + with right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. This + is only set for sharded strategies since they require padding for + the all-gather. + _sharded_size (torch.Size): Sharded flat parameter's size with padding. + This is also set for ``NO_SHARD``, in which case it is the same as + the unsharded sizes. (We omit "padded" because there is no + analogous unpadded one.) + + _num_params (int): Number of original parameters flattened into this + flat parameter. This is the length of the per-parameter data + structures. + _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info + entry; see :class:`ParamInfo` for details. + _shapes (Tuple[torch.Size, ...]): Each parameter's original shape. + _strides (Tuple[torch.Size, ...]): Each parameter's original stride. + _contiguities (Tuple[bool, ...]): Each parameter's ``contiguous()`` + call result. + _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN) + prefixed from the ``_fully_sharded_module``. The names are + guaranteed to be unique in the subtree rooted at that module. + _param_extensions (Tuple[Optional[Any], ...]): Each parameter's + extension (i.e. some per-parameter state) used to customize + pre-flatten and post-unflatten behavior or ``None``. This is + experimental, and users should not depend on its existence in the + future. + _numels_with_padding (Tuple[int, ...]): Each parameter's numel + including entries for the padding. This is used to construct views + into the flat parameter via ``torch.split()``. This may have length + longer than ``_num_params``. + _numels (Tuple[int, ...]): Each parameter's numel excluding entries for + padding. This has length equal to ``_num_params``. + _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's + shard parameter info; see :class:`_ShardParamInfo` for details. + _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter + info entries; see :class:`SharedParamInfo` for details. + _modules (set[nn.Module]): Modules that contain some original parameter + that is flattened into the flat parameter. + + _shard_numel_padded (int): Numel padded for this rank's sharded flat + parameter. + _local_shard (Tensor): Sharded flat parameter with padding if using a + sharded strategy. If using ``NO_SHARD``, then this is the unpadded + unsharded flat parameter, and there is no notion of a sharded flat + parameter or padded unsharded flat parameter. + _full_param_padded (Tensor): Unsharded flat parameter with padding. + This is not defined for ``NO_SHARD``. When using mixed precision + for parameters, this has the low precision. + _full_prec_full_param_padded (Tensor): Full precision unsharded flat + parameter with padding. This is used for unsharding outside of + computation when using mixed precision for parameters. This is + never defined for ``NO_SHARD``. + _post_backward_hook_handle (RemovableHandle): + Flat parameter's post-backward hook handle. (Compile only) + _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]): + Flat parameter's :class:`AccumulateGrad` object and post-backward + hook handle. (Eager only) + _mp_shard (Tensor): Low precision sharded flat parameter with padding. + This is only defined when parameter mixed precision is enabled. For + ``NO_SHARD``, this is used for computation. + _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. + This is only defined when offloading parameters is enabled. + _saved_grad_shard (Tensor): Sharded gradient with padding from previous + iterations for gradient accumulation without :meth:`no_sync`. + + _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``, + then each original parameter variable; otherwise, ``None``. This + does not include any padding tensors. + _shared_params (Optional[List[nn.Parameter]]): The original shared + parameter variables if ``use_orig_params=True`` and ``None`` + otherwise. + _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor`` + views created in the forward and tracked by autograd when + ``use_orig_params=True`` and is ``None`` otherwise. This is to + preserve those ``Tensor`` variables for the backward to ensure that + the ``FlatParameter`` 's ``AccumulateGrad`` object does not change + in which case the post-backward hook does not run. This is relevant + for cases like reentrant activation checkpointing. + _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``, + a mask over the original parameters' gradients indicating if it is + logically ``None`` or not; otherwise, ``None``. This does not + include entries for padding. This mask is needed because only some + of the parameters may have ``None`` gradient, in which case the + flat gradient must be non-``None`` and must use zeros to + approximate those original ``None`` gradients. This mask informs + FSDP to set the original parameter gradients to ``None`` (instead + of zeros) as needed. + """ + + _unpadded_unsharded_size: torch.Size + _padded_unsharded_size: torch.Size + _sharded_size: torch.Size + _num_params: int + _param_infos: tuple[ParamInfo, ...] + _shapes: tuple[torch.Size, ...] + _strides: tuple[tuple[int, ...], ...] + _contiguities: tuple[bool, ...] + _fqns: tuple[str, ...] + _param_extensions: tuple[Optional[Any], ...] + _numels_with_padding: tuple[int, ...] + _numels: tuple[int, ...] + _shard_param_infos: tuple[_ShardParamInfo, ...] + _shared_param_infos: tuple[SharedParamInfo, ...] + _modules: set[nn.Module] + _shard_numel_padded: int + _local_shard: Tensor + _full_param_padded: Tensor + _full_prec_full_param_padded: Tensor + # Eager only + _post_backward_hook_state: tuple[Any, Any] + # Compile only + _post_backward_hook_handle: Any + _mp_shard: Tensor + _cpu_grad: Tensor + _saved_grad_shard: Tensor + _params: Optional[list[nn.Parameter]] + _shared_params: Optional[list[nn.Parameter]] + _tensors: Optional[list[Optional[Tensor]]] + _is_grad_none_mask: Optional[list[bool]] + + _is_padding_mask: list[bool] + + def __new__(cls, data=None, requires_grad=True): + if cls is not FlatParameter: + raise AssertionError("subclasses FlatParameter not supported") + r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg] + r._is_flat_param = True # type: ignore[attr-defined] + return r + + # NB: This is not a regular method, because FlatParameters are not actually + # instances of this class (see __new__ above). So you must indirectly + # call this directly through the classmethod. + @classmethod + def _init_metadata( + cls, + self, + param_infos: list[ParamInfo], + numels: list[int], + shapes: list[torch.Size], + strides: list[tuple[int, ...]], + contiguities: list[bool], + fqns: list[str], + shared_param_infos: list[SharedParamInfo], + param_extensions: list[Optional[Any]], + params: Optional[list[nn.Parameter]], + shared_params: Optional[list[nn.Parameter]], + is_padding_mask: list[bool], + ) -> None: + """ + Initialize attributes holding metadata about the original parameters comprising the flat parameter. + + We expose this method separate from the constructor to keep the + constructor only responsible for the flat parameter's tensor data. This + method should only be called once per model, while the constructor may + be called multiple times, e.g. when reloading from a checkpoint, in + which case only the tensor data needs to be passed to the constructor. + Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the + metadata is correctly assumed to be unchanged. + + Args: + See the Attributes in the class docstring. + """ + if len(param_infos) != len(shapes): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match shapes length {len(shapes)}" + ) + if len(param_infos) != len(strides): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match strides length {len(strides)}" + ) + if len(param_infos) != len(contiguities): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match contiguities length {len(contiguities)}" + ) + if len(param_infos) != len(fqns): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match fqns length {len(fqns)}" + ) + if len(param_infos) != len(param_extensions): + raise AssertionError( + f"Expected param_infos length {len(param_infos)} to match param_extensions length {len(param_extensions)}" + ) + self._num_params = len(param_infos) + self._param_infos = param_infos + self._shapes = shapes + self._strides = strides + self._contiguities = contiguities + self._fqns = fqns + self._param_extensions = param_extensions + self._is_padding_mask = is_padding_mask + + numels_without_padding: list[int] = [] + for numel, is_padding in zip(numels, is_padding_mask): + if not is_padding: + numels_without_padding.append(numel) + self._numels = tuple(numels_without_padding) + self._numels_with_padding = tuple(numels) + if len(self._numels) != self._num_params: + raise AssertionError( + f"Expected _numels length {len(self._numels)} to equal _num_params {self._num_params}" + ) + + self._shared_param_infos = tuple(shared_param_infos) + self._modules = {pi.module for pi in self._param_infos}.union( + {spi.module for spi in self._shared_param_infos} + ) + if (params is None) != (shared_params is None): + raise AssertionError( + "Expected params and shared_params to both be None or both be not None" + ) + if params is not None: + if shared_params is None or len(shared_params) != len(shared_param_infos): + raise AssertionError( + f"Expected shared_params to be not None and have length {len(shared_param_infos)}, got {shared_params}" + ) + self._params = [] + for param, is_padding in zip(params, is_padding_mask): + if not is_padding: + self._params.append(param) + if shared_params is not None: + self._shared_params = shared_params + else: + self._shared_params = [] + # Mark the original parameters to avoid flattening them into + # another `FlatParameter` during recursive construction + for param in chain(self._params, self._shared_params): + _set_fsdp_flattened(param) + self._is_grad_none_mask = [False for _ in range(self._num_params)] + self._tensors = [None for _ in range(self._num_params)] + else: + self._params = None + self._shared_params = None + self._is_grad_none_mask = None + self._tensors = None + self._unpadded_unsharded_size = self.size() + _set_fsdp_flattened(self) + # Tracks whether the `FlatParameter`'s post-backward hook has been + # called to modify the behavior of the post-backward callback + self._post_backward_called = False + + +class FlatParamHandle: + """ + A handle that manages a flat parameter (:class:`FlatParameter`). + + This includes sharding and view management. + + Args: + params (Sequence[nn.Parameter]): The parameters to flatten into the + flat parameter. + fully_sharded_module (nn.Module): See [Note: Fully Sharded Module]. + device (torch.device): The compute and communication device, which + should be a non-CPU device. We refer to it as the compute device. + sharding_strategy (ShardingStrategy): Sharding strategy to apply to + this handle's ``FlatParameter``. + offload_params (bool): Whether to offload the handle's + ``FlatParameter`` to CPU. + mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision + setting passed to the FSDP constructor. + mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed + precision setting passed to the FSDP constructor. + keep_low_precision_grads (bool): Whether to keep gradients in low + precision. + use_orig_params (bool): If ``True``, then FSDP preserves the original + parameter variables and returns them from ``named_parameters()`` + (e.g. to support different optimizer hyperparameters within one + :class:`FlatParameter`). If ``False``, then FSDP reconstructs the + parameters every iteration and returns the :class:`FlatParameter` s + from ``named_parameters()``. + """ + + ################## + # INITIALIZATION # + ################## + def __init__( + self, + params: Sequence[Union[nn.Parameter, Tensor]], + fully_sharded_module: nn.Module, + device: torch.device, + sharding_strategy: HandleShardingStrategy, + offload_params: bool, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + keep_low_precision_grads: bool, + process_group: dist.ProcessGroup, + use_orig_params: bool, + *, + fsdp_extension: Optional[FSDPExtensions] = None, + ): + super().__init__() + params = list(params) + if len(params) == 0: + raise ValueError( + f"Cannot construct a {self.__class__.__name__} with an empty parameter list" + ) + self._init_setattr_fns() + self._skip_writeback_check = ( + os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1" + ) + self._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1" + self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1" + if self._skip_writeback_check: + _warn_skip_writeback_check( + logger, + f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check " + "for parameter or gradient writeback. Changing parameter or " + "gradient storages may lead to silent correctness errors.", + ) + if self._use_fake_all_gather: + _warn_use_fake_all_gather( + logger, + f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute " + "all-gather ops. Your training will be incorrect, but " + "can reveal how much time spent on all-gather ops.", + ) + if self._use_fake_reduce: + _warn_use_fake_reduce( + logger, + f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute " + "reduce-scatter ops. Your training will be incorrect, but " + "can reveal how much time spent on reduce-scatter ops.", + ) + # Only align addresses for `use_orig_params=True` (for now) + align_addresses = use_orig_params + self._init_get_unflat_views_fn(align_addresses) + # pyrefly: ignore [read-only] + self.device = device + self._device_handle = _FSDPDeviceHandle.from_device(self.device) + self.process_group = process_group + if self._use_fake_all_gather or self._use_fake_reduce: + self._fake_process_group = FakeProcessGroup._create_internal( + rank=process_group.rank(), world_size=process_group.size() + ) + self.rank = process_group.rank() + self.world_size = process_group.size() + self._sharding_strategy = sharding_strategy + self._offload_params = offload_params + self._use_orig_params = use_orig_params + self._keep_low_precision_grads = keep_low_precision_grads + self._training_state = HandleTrainingState.IDLE + self._debug_level = dist.get_debug_level() + self._fully_sharded_module = fully_sharded_module + # For strategies that do not free after forward, we skip using sharded + # views after forward since the unsharded data exists. We still switch + # `self.flat_param` to point to the sharded flat parameter since what + # it points to parameterizes behavior. We use the following attribute + # to track which tensor data the parameters are unsharded views into. + self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None + # The index in the state's `all_handles`, which must be the + # same across ranks for the execution order validation to work + self._handle_index: Optional[int] = None + # Index in handles_to_pre_forward_order + self._pre_forward_order_index: Optional[int] = None + # Index in `handles_post_forward_order` + self._post_forward_index: Optional[int] = None + # Used for guarding against mistargeted forward prefetches + self._needs_pre_forward_unshard = False + # Used for guarding against mistargeted backward prefetches + self._needs_pre_backward_unshard = False + # Was the handle prefetched? Set on successful _prefetch_handle and unshard + self._prefetched = False + # Optimistically assume a valid input `params` and set dtype attributes + # before `_init_flat_param()`, which performs the actual validation + self._orig_param_dtype = params[0].dtype + self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype) + if self._fwd_bwd_param_dtype is None: + raise AssertionError("Expected _fwd_bwd_param_dtype to be not None") # mypy + self._aligned_numel = ( + _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype) + if align_addresses + else 0 + ) + self._fsdp_extension = fsdp_extension + self._init_flat_param_and_metadata( + params, + fully_sharded_module, + self._aligned_numel, + use_orig_params, # type: ignore[arg-type] + ) + self._use_unsharded_views(as_params=False) + + def __repr__(self): + return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})" + + def _init_setattr_fns(self): + use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1" + self._setattr_tensor: Callable[[nn.Module, str, Tensor], None] + self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None] + if use_unsafe_setattr: + self._setattr_tensor = _unsafe_setattr_tensor + self._setattr_param = _unsafe_setattr_param + else: + self._setattr_tensor = _safe_setattr_tensor_or_param + self._setattr_param = _safe_setattr_tensor_or_param + + def _init_get_unflat_views_fn(self, align_addresses: bool): + self._get_unflat_views = ( + self._get_unflat_views_aligned + if align_addresses + else self._get_unflat_views_unaligned + ) + + def _init_flat_param_and_metadata( + self, + params: list[Union[Tensor, nn.Parameter]], + module: nn.Module, + aligned_numel: int, + use_orig_params: bool, + ) -> None: + """ + Initialize the ``FlatParameter`` and its metadata. + + NOTE: This should only be called once at construction time, after which + the ``FlatParameter`` metadata is assumed to be static. + + NOTE: The elements of ``params`` should only be ``Tensor`` s when + composing with ``DTensor`` -based tensor parallelism, in which case the + elements may be ``DTensor`` local shards. + """ + if len(params) == 0: + raise ValueError("Expects non-empty `params`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + ( + dtype, + flat_param_requires_grad, + device, + ) = self._validate_tensors_to_flatten(params) + params_set = set(params) + # For alignment padding, only `numels` gets strictly non-`None` + # elements, and all other lists get `None` elements for padding. + param_infos: list[ParamInfo] = [] + numels: list[int] = [] + shapes: list[torch.Size] = [] + strides: list[tuple[int, ...]] = [] + contiguities: list[bool] = [] + fqns: list[str] = [] + shared_param_infos: list[SharedParamInfo] = [] + shared_param_memo: dict[ + Union[Tensor, nn.Parameter], tuple[nn.Module, str, str] + ] = {} + params_to_flatten: list[Union[Tensor, nn.Parameter]] = [] + shared_params: list[Union[Tensor, nn.Parameter]] = [] + param_extensions: list[Any] = [] + is_padding_mask: list[bool] = [] + total_numel = total_numel_without_padding = 0 + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param not in params_set: + continue + if param in shared_param_memo: # shared reference + prim_module, prim_module_name, prim_param_name = shared_param_memo[ + param + ] + shared_params.append(param) + shared_param_infos.append( + SharedParamInfo( + param_name, + submodule, + submodule_name, + prim_param_name, + prim_module, + prim_module_name, + ) + ) + else: + if aligned_numel > 0: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + transform_t, extension = _ext_pre_flatten_transform( + param, + self._fsdp_extension, + ) + param = cast(nn.Parameter, transform_t) + param_extensions.append(extension) + shared_param_memo[param] = (submodule, submodule_name, param_name) + params_to_flatten.append(param) + is_padding_mask.append(False) + param_infos.append(ParamInfo(param_name, submodule, submodule_name)) + numels.append(param.numel()) + shapes.append(param.shape) + strides.append(param.stride()) + contiguities.append(_is_truly_contiguous(param)) + fqn = ( + submodule_name + "." + param_name + if submodule_name + else param_name + ) + fqns.append(fqn) + total_numel += param.numel() + total_numel_without_padding += param.numel() + if len(params_to_flatten) == 0: + raise ValueError( + f"`params` were not found in `module`'s tree" + f"params: {params}\nmodule: {module}" + ) + if ( + self.rank == 0 + and aligned_numel > 0 + and total_numel != total_numel_without_padding + ): + logger.debug( + "FSDP FlatParameter address alignment created " + "%s numel of padding (%s vs. %s)", + total_numel - total_numel_without_padding, + total_numel, + total_numel_without_padding, + ) + if aligned_numel > 0: + # Pad to be divisible by world size to avoid a copy for the + # post-backward reduce-scatter + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + if self.rank == 0: + logger.info( + "FSDP FlatParameter world size divisibility created " + "%s numel of padding", + numel_to_pad, + ) + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + # Pass `aligned_numel=0` since we already included padding tensors + self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param( + params_to_flatten, + aligned_numel=0, + requires_grad=flat_param_requires_grad, + ) + FlatParameter._init_metadata( + self.flat_param, + param_infos, + numels, + shapes, + strides, + contiguities, + fqns, + shared_param_infos, + param_extensions, + _convert_to_params(params_to_flatten) if use_orig_params else None, + _convert_to_params(shared_params) if use_orig_params else None, + is_padding_mask, + ) + + def _validate_tensors_to_flatten( + self, tensors: list[Union[Tensor, nn.Parameter]] + ) -> tuple: + """Validate the tensors to flatten and returns any necessary metadata.""" + dtype: Optional[torch.dtype] = None + # Return as the logical OR over each tensor's value + flat_param_requires_grad: Optional[bool] = None + device: Optional[torch.device] = None + # For `use_orig_params=True`, permit non-uniform `requires_grad` + for tensor in tensors: + if isinstance(tensor, FlatParameter): + raise ValueError("Cannot flatten a `FlatParameter`") + if dtype is None and not tensor.is_floating_point(): + raise ValueError("Cannot flatten integer dtype tensors") + if dtype is not None and tensor.dtype != dtype: + raise ValueError( + f"Must flatten tensors with uniform dtype but got {dtype} " + f"and {tensor.dtype}" + ) + if ( + not self._use_orig_params + and flat_param_requires_grad is not None + and tensor.requires_grad != flat_param_requires_grad + ): + raise ValueError( + "Must flatten tensors with uniform `requires_grad` when " + "`use_orig_params=False`" + ) + if device is not None and tensor.device != device: + raise ValueError( + "Must flatten tensors on the same device but got both " + f"{device} and {tensor.device}" + ) + dtype = tensor.dtype + flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad + device = tensor.device + if flat_param_requires_grad is None: + raise AssertionError("Requires non-empty `tensors` list") + return dtype, flat_param_requires_grad, device + + def flatten_tensors( + self, + tensors: list[Tensor], + aligned_numel: int, + ) -> Tensor: + """ + Flatten ``tensors`` into a single flat tensor. + + The flattening optionally includes + padding if ``aligned_numel`` is greater than 0, where ``aligned_numel`` + gives the numel required to have address alignment. + + NOTE: The padding alignment algorithm must be kept in sync with + :meth:`_init_flat_param_metadata`. We separate the two methods because + the initialization happens once, whereas this method may be called + multiple times throughout training (e.g. for checkpointing). + """ + if len(tensors) == 0: + raise ValueError("Expects non-empty `tensors`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + dtype, _, device = self._validate_tensors_to_flatten(tensors) + flat_tensors: list[Tensor] = [] + if aligned_numel > 0: + total_numel = 0 + for tensor in tensors: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + flat_tensors.append( + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + ) + total_numel += tensor.numel() + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + else: + flat_tensors = [ + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + for tensor in tensors + ] + return torch.cat(flat_tensors, dim=0) + + def flatten_tensors_into_flat_param( + self, + tensors: list[Tensor], + aligned_numel: int, + requires_grad: bool, + ) -> FlatParameter: + flat_param_data = self.flatten_tensors(tensors, aligned_numel) + return FlatParameter(flat_param_data, requires_grad=requires_grad) + + def _init_param_reduce_dtypes( + self, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + ) -> None: + """ + Initialize param and reduce dtypes. + + Precondition: ``self.flat_param`` is set. This ensures that this + handle's parameters have a single dtype. + + Postcondition: This sets ``self._fwd_bwd_param_dtype`` and + ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype`` + is ``None``, then we assume the original parameter dtype. One special + case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype`` + is ``None``, in which case we assume the gradient reduction dtype + matches the forward/backward parameter dtype. + """ + # Save whether these dtypes were specified so that we permit the + # parameter dtype to change up until the lazy initialization + self._low_prec_param_dtype_specified = mp_param_dtype is not None + self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None + if ( + self._low_prec_param_dtype_specified + and not self._low_prec_reduce_dtype_specified + ): + # Special case: infer gradient reduction mixed precision + self._fwd_bwd_param_dtype = mp_param_dtype + self._reduce_dtype = self._fwd_bwd_param_dtype + else: + self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype + self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype + if self._fwd_bwd_param_dtype is None: + raise AssertionError("Expected _fwd_bwd_param_dtype to be not None") + if self._reduce_dtype is None: + raise AssertionError("Expected _reduce_dtype to be not None") + + ################################### + # SHARD INITIALIZATION & METADATA # + ################################### + @torch.no_grad() + def shard(self): + """ + Shard the handle's ``FlatParameter``. + + This allocates new memory for + the sharded flat parameter and frees the unsharded flat parameter's + storage. + + Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard + metadata attributes are set for all sharding strategies. + """ + flat_param = self.flat_param + if not self.uses_sharded_strategy: + self._init_shard_metadata(0, 0, flat_param.numel() - 1) + else: + _p_assert( + flat_param.storage_offset() == 0, + "The `FlatParameter` is not the sole occupant of its storage", + ) + sharded_flat_param, numel_padded = FlatParamHandle._get_shard( + flat_param, self.rank, self.world_size + ) + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + allocated = flat_param._typed_storage()._size() > 0 + if allocated: + flat_param._typed_storage()._resize_(0) + flat_param.set_(sharded_flat_param) # type: ignore[call-overload] + start_idx = sharded_flat_param.numel() * self.rank + end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive + self._init_shard_metadata(numel_padded, start_idx, end_idx) + if self._use_orig_params: + self._use_sharded_views() + + def _init_shard_metadata( + self, + numel_padded: int, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> None: + """ + Initialize shard-related metadata for this rank's shard of the flat parameter. + + This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``. + + Args: + numel_padded (int): Numel padded for this rank's sharded flat + parameter. + unsharded_start_idx (int): Start index in the unsharded flat + parameter assigned to this rank. + unsharded_end_idx (int): End index (inclusive) in the unsharded + flat parameter assigned to this rank. + + Precondition: ``self.flat_param`` 's data is the sharded flat + parameter. + """ + flat_param = self.flat_param + flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined] + sharded_flat_param_numel = flat_param.numel() # includes `numel_padded` + _p_assert( + unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx, + f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}", + ) + _p_assert( + numel_padded <= sharded_flat_param_numel, + f"numel_padded: {numel_padded} " + f"sharded_flat_param_numel: {sharded_flat_param_numel}", + ) + shard_param_infos = self._get_shard_metadata( + unsharded_start_idx, unsharded_end_idx + ) + if len(shard_param_infos) != flat_param._num_params: + raise AssertionError( + f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" + ) + flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined] + flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] + + def _get_shard_metadata( + self, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> tuple[_ShardParamInfo, ...]: + """ + Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive). + + ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the + unsharded flat parameter specifying the shard. + """ + flat_param_offsets = self._get_flat_param_offsets() + if len(flat_param_offsets) != len(self.flat_param._numels_with_padding): + raise AssertionError( + f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" + ) + shard_param_infos: list[_ShardParamInfo] = [] + sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 + # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices + # into the unsharded flat parameter (inclusive) of the given parameter + for ( + (unsharded_param_start_idx, unsharded_param_end_idx), + is_padding, + ) in zip(flat_param_offsets, self.flat_param._is_padding_mask): + if is_padding: + continue + in_sharded_flat_param = ( + unsharded_start_idx <= unsharded_param_end_idx + and unsharded_end_idx >= unsharded_param_start_idx + ) + if not in_sharded_flat_param: + shard_param_info = _ShardParamInfo(False, None, None, None, None) + else: + if unsharded_start_idx <= unsharded_param_start_idx: + # This branch can only happen once since the rank's + # unsharded start index can only intersect one parameter + intra_param_start_idx = 0 + offset_in_shard = unsharded_param_start_idx - unsharded_start_idx + else: + intra_param_start_idx = ( + unsharded_start_idx - unsharded_param_start_idx + ) + offset_in_shard = 0 + if not ( + offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel + ): + raise AssertionError( + f"Invalid `offset_in_shard` of {offset_in_shard} for " + f"sharded flat parameter with {sharded_flat_param_numel} numel" + ) + intra_param_end_idx = ( + min(unsharded_param_end_idx, unsharded_end_idx) + - unsharded_param_start_idx + ) + numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1 + shard_param_info = _ShardParamInfo( + True, + offset_in_shard, + numel_in_shard, + intra_param_start_idx, + intra_param_end_idx, + ) + shard_param_infos.append(shard_param_info) + return tuple(shard_param_infos) + + @staticmethod + def _get_unpadded_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``. + + The returned value is a tuple of the shard of ``tensor`` without any + padding and the numel to pad for that shard. + + If ``tensor`` is already flattened or may be viewed in the flattened + shape (which is true in the expected usage), then this method does not + allocate any new tensor memory. + """ + chunks = ( + torch.flatten(tensor).chunk(world_size) + if _is_truly_contiguous(tensor) + else tensor.as_strided((tensor.numel(),), (1,)).chunk(world_size) + ) + if len(chunks) < (rank + 1): + # This rank gets an empty chunk fully padded with zeros since there + # are not enough chunks across ranks + chunk = chunks[0].new_empty(0) + else: + chunk = chunks[rank] + numel_to_pad = chunks[0].numel() - chunk.numel() + if numel_to_pad < 0: + raise AssertionError( + "Chunk's size should be at most the first chunk's size" + ) + return chunk, numel_to_pad + + @staticmethod + def _get_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard. + + This method allocates new memory (via :meth:`clone`) since the + unsharded ``tensor`` may be deallocated after this method returns. + """ + chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + shard = chunk.clone() + if numel_to_pad > 0: + shard = F.pad(shard, [0, numel_to_pad]) + return shard, numel_to_pad + + @staticmethod + def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: + """ + Return the shape of ``tensor`` after sharding including padding. + + This requires ``tensor`` to have 1D shape and ensures that the returned + shape is 1D. + """ + if len(tensor.shape) != 1: + raise AssertionError(f"Expected 1D tensor shape, got {tensor.shape}") + unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + unpadded_sharded_size = unpadded_sharded_tensor.size() + if len(unpadded_sharded_size) != 1: + raise AssertionError( + f"Expected 1D unpadded_sharded_size, got {unpadded_sharded_size}" + ) + return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) + + def _get_flat_param_offsets(self) -> list[tuple[int, int]]: + """ + Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding). + + NOTE: The returned list includes elements for alignment padding. + """ + cumulative_sum = list(accumulate(self.flat_param._numels_with_padding)) + starts = [0] + cumulative_sum[:-1] + ends = [end - 1 for end in cumulative_sum] # inclusive + param_offsets = list(zip(starts, ends)) + return param_offsets + + @no_type_check + def shard_metadata( + self, + ) -> FlatParamShardMetadata: + """ + Return the shard-related metadata specific to this rank's shard of the flat parameter. + + NOTE: The returned tuple does not include elements for alignment + padding but does account for the padding. + """ + fqns_list = [] + shapes_list = [] + strides_list = [] + contiguities_list = [] + numels_list = [] + shard_param_offsets = [] + for fqn, shape, stride, contiguous, numel, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shapes, + self.flat_param._strides, + self.flat_param._contiguities, + self.flat_param._numels, + self.flat_param._shard_param_infos, + ): + if not shard_param_info.in_shard: + continue + fqns_list.append(fqn) + shapes_list.append(shape) + strides_list.append(stride) + contiguities_list.append(contiguous) + numels_list.append(numel) + shard_param_offsets.append( + ( + shard_param_info.intra_param_start_idx, + shard_param_info.intra_param_end_idx, + ) + ) + return FlatParamShardMetadata( + tuple(fqns_list), + tuple(shapes_list), + tuple(strides_list), + tuple(contiguities_list), + tuple(numels_list), + tuple(shard_param_offsets), + ) + + @no_type_check + @torch.no_grad() + def init_flat_param_attributes(self) -> None: + """ + This initializes some attributes on the handle's ``FlatParameter``. + This should be called during lazy initialization since it requires the + parameter to be on the compute device if not offloading to CPU and we + want to give users the chance to move the parameter appropriately after + the FSDP constructor. + + For each tensor attribute on the ``FlatParameter``, see the unshard and + reshard methods in this class for the allocation and free pattern. + """ + flat_param = self.flat_param + if flat_param.dtype != self._orig_param_dtype: + # Entering this branch means that the user changed the parameter + # dtype after FSDP initialization, in which case we may need to + # refresh some saved dtype attributes (dtypes specified as a part + # of mixed precision take precedence). + if not self._low_prec_param_dtype_specified: + self._fwd_bwd_param_dtype = flat_param.dtype + # For `reduce_dtype`, require `param_dtype` was not specified since + # then we infer the `reduce_dtype` from the specified `param_dtype` + if ( + not self._low_prec_reduce_dtype_specified + and not self._low_prec_param_dtype_specified + ): + self._reduce_dtype = flat_param.dtype + self._orig_param_dtype = flat_param.dtype + cpu_device = torch.device("cpu") + if self._offload_params: + _p_assert( + flat_param.device == cpu_device, + f"Expects the `FlatParameter` to be on CPU when parameter CPU " + f"offloading is enabled, not {flat_param.device}", + ) + else: + self._check_on_compute_device(self.flat_param) + flat_param._local_shard = flat_param.data + if self._offload_params: + # Pin the memory for faster H2D transfer + flat_param._local_shard = flat_param._local_shard.pin_memory() + # Pre-allocate the sharded gradient on CPU to enable non-blocking + # D2H transfer during the backward pass + flat_param._cpu_grad = torch.zeros_like( + flat_param._local_shard, device=cpu_device + ).pin_memory() + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a low precision + # sharded tensor on the compute device to be all-gathered (for + # sharded strategies) or directly used (for `NO_SHARD`) for + # computation. + flat_param._mp_shard = torch.empty_like( + flat_param._local_shard, + device=self.device, + dtype=self._fwd_bwd_param_dtype, + ) + _free_storage(flat_param._mp_shard) + if self.uses_sharded_strategy: + # We maintain a padded unsharded tensor that serves as the + # all-gather destination and owns the original parameter storages. + unsharded_param_dtype = ( + self._fwd_bwd_param_dtype + if self._uses_param_mixed_precision + else flat_param.dtype + ) # use low precision if parameter mixed precision is enabled + padded_unsharded_numel = flat_param.numel() * self.world_size + flat_param._full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=unsharded_param_dtype, + ) + flat_param._padded_unsharded_size = flat_param._full_param_padded.size() + _free_storage(flat_param._full_param_padded) + + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a full precision + # padded unsharded tensor for when we force full precision. + flat_param._full_prec_full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=flat_param.dtype, # full precision + ) + _free_storage(flat_param._full_prec_full_param_padded) + + ################### + # UNSHARD/RESHARD # + ################### + def pre_unshard(self) -> bool: + """ + Return ``False`` if this is a no-op and ``True`` otherwise. + + Postcondition: ``self.flat_param`` 's data is on the device for + communication and is what should be all-gathered. This means that it + matches the dtype of the expected unsharded parameter. + """ + if ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + and self._skipped_use_sharded_views + ): + # Since this path imposes special semantics for the unsharded flat + # parameter (e.g. forcing full precision), use sharded views to + # reuse the existing logic for that special handling + self._use_sharded_views() + ret = False + if self._use_orig_params and not self._skip_writeback_check: + ret = self._writeback_orig_params() + if ( + self.uses_sharded_strategy + and not self._offload_params + and not self.needs_unshard() + ): + pass # no-op + elif self._uses_param_mixed_precision and not self._force_full_precision: + self._use_low_precision_shard() + ret = True + elif self._offload_params and self.flat_param.device != self.device: + # NOTE: This creates a new tensor distinct from any attributes. + self.flat_param_to(self.device, non_blocking=True) + ret = True + self._check_on_compute_device(self.flat_param) + return ret + + def _use_low_precision_shard(self): + """Allocate on the compute device and switch to using the low precision sharded flat parameter.""" + self._check_low_precision_shard() + flat_param = self.flat_param + _alloc_storage( + flat_param._mp_shard, + flat_param._local_shard.size(), # type: ignore[attr-defined] + ) + # `copy_()` implicitly casts to the low precision + flat_param._mp_shard.copy_( # type: ignore[attr-defined] + flat_param._local_shard.to( # type: ignore[attr-defined] + self.device, non_blocking=True + ) + ) + # Invariant: `_mp_shard` is always on the compute device. + flat_param.data = flat_param._mp_shard # type: ignore[attr-defined] + + def unshard(self): + """ + Run the unshard logic. + + This includes all-gathering the flat parameter + and switching to using the unsharded flat parameter. If the handle does + not need unsharding, then this only switches to using the unsharded + flat parameter. For ``NO_SHARD``, this is a no-op. + + If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + # Even when not needing an unshard, we should switch to using + # the unsharded flat parameter + unsharded_flat_param = ( + self._get_padded_unsharded_flat_param() + if self.uses_sharded_strategy + else self.flat_param + ) + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def needs_unshard(self) -> bool: + """Return if the handle's flat parameter needs to be unsharded.""" + if not self.uses_sharded_strategy: + return False + unsharded_flat_param = self._get_padded_unsharded_flat_param() + already_unsharded = _same_storage_size( + unsharded_flat_param, unsharded_flat_param.numel() + ) + return not already_unsharded + + def _alloc_padded_unsharded_flat_param(self): + """ + Allocate the *padded* unsharded flat parameter. + + The unpadded unsharded + flat parameter is always a view into the padded one. This padded + parameter is saved to a different attribute on the ``FlatParameter`` + depending on if we force full precision. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_storage_freed(unsharded_flat_param) + _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined] + return unsharded_flat_param + + def _get_padded_unsharded_flat_param(self) -> torch.Tensor: + """ + Return a reference to the padded unsharded flat parameter depending on the calling context. + + This should only be called if using a sharded strategy. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + if self._force_full_precision and self._uses_param_mixed_precision: + # When parameter mixed precision is enabled, we use a different + # tensor as the all-gather destination to preserve the invariant + # that `_full_param_padded` is in the low precision + unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] + _p_assert( + unsharded_flat_param.dtype != self._fwd_bwd_param_dtype, + f"Expects full precision but got {self._fwd_bwd_param_dtype}", + ) + # For no-reshard-after-forward strategies, `_full_param_padded` may + # still be allocated from a previous forward. As we are forcing + # full precision here, the full-precision unsharded copy may be + # modified, invalidating the existing low-precision unsharded copy, + # so we should free it here to ensure a new all-gather for the next + # forward/backward computation to persist the modifications. + if flat_param._full_param_padded.untyped_storage().size() > 0: + _free_storage(flat_param._full_param_padded) + else: + unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined] + return unsharded_flat_param + + def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, + ) -> Tensor: + """ + All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``. + + Then switch to use the all-gathered tensor. + """ + _p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + _p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + + pg = ( + self._fake_process_group + if self._use_fake_all_gather + else self.process_group + ) + + # HACK this should be handled by C10D + if sharded_flat_param.is_cpu: # type: ignore[attr-defined] + tensor_list = list( + torch.chunk( + padded_unsharded_flat_param, + dist.get_world_size(pg), # type: ignore[arg-type] + ) + ) + dist.all_gather(tensor_list, sharded_flat_param, group=pg) + else: + dist.all_gather_into_tensor( + padded_unsharded_flat_param, + sharded_flat_param, + pg, + ) + + if self._offload_params: + # In case of offloading, `flat_param.data` (i.e. sharded param) is + # created on the pre-unshard stream. We need to hand it over to the + # unshard stream for all-gather + _no_dispatch_record_stream( + sharded_flat_param, + self._device_handle.current_stream(), # unshard_stream + ) + return padded_unsharded_flat_param + + def _use_unsharded_flat_param( + self, + padded_unsharded_flat_param: torch.Tensor, + ) -> None: + """ + Switch to use the *unpadded* unsharded flat parameter. + + This is a view into the *padded* unsharded flat parameter. + """ + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()] + # slicing [:] is not visible to autograd because of .data + self.flat_param.data = flat_param_part + in_forward = self._training_state == HandleTrainingState.FORWARD + in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE + if self._use_orig_params: + if self._skipped_use_sharded_views and in_pre_backward: + # This call corresponds to the complementary pre-backward + # `_use_unsharded_views()` to the skipped pre-forward + # `_use_sharded_views()`, so we should skip this one too. + return + # We use `Tensor` views in the forward so that they are tracked by + # autograd. We use them in the pre-backward as well to support + # reentrant activation checkpointing, which needs the views to be + # tracked by autograd in the backward pass's recomputed forward. + self._use_unsharded_views( + as_params=(not in_forward and not in_pre_backward) + ) + elif in_forward: + self._use_unsharded_views(as_params=False) + + def post_unshard(self): + """ + Run the post-unshard logic. + + This includes freeing the low precision shard if needed. + """ + if self._uses_param_mixed_precision and self.uses_sharded_strategy: + self._free_low_precision_sharded_param() + self._check_on_compute_device(self.flat_param) + + def _free_low_precision_sharded_param(self): + """Frees the low precision sharded flat parameter.""" + self._check_low_precision_shard() + # `_mp_shard` is allocated in the pre-unshard stream, consumed in the + # unshard stream for sharded strategies, and consumed in both the + # unshard and default streams for `NO_SHARD`. For sharded strategies, + # the current stream here is the unshard stream, and for `NO_SHARD`, + # it is the default stream. For `NO_SHARD`, only recording for the + # default stream suffices since the default stream waits for the + # unshard stream. + _no_dispatch_record_stream( + self.flat_param._mp_shard, + self._device_handle.current_stream(), # type: ignore[attr-defined] + ) + _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] + + @torch.no_grad() + def unshard_grad(self): + """ + Unshard the handle's ``FlatParameter``'s gradient. + + If all ranks have + ``None`` gradient, then all original parameters will as well. This + method performs an all-reduce and an all-gather. The additional + all-reduce is tolerable since this method is not meant to be used on + the computation critical path. + + Postcondition: ``_saved_grad_shard`` is defined and contains the value + to set ``flat_param.grad`` after gradients are resharded. + """ + if not self.uses_sharded_strategy: + self._use_unsharded_grad_views() + return + flat_param = self.flat_param + self._check_unsharded(flat_param) + + # Check if all ranks have a `None` gradient + num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device) + num_grad_none[0] = flat_param.grad is None + dist.all_reduce(num_grad_none, group=self.process_group) + if num_grad_none[0] == self.world_size: + flat_param._saved_grad_shard = None # type: ignore[assignment] + self._use_unsharded_grad_views() + return + + if flat_param.grad is None: + # In the case that only some ranks have `None` gradient, we use + # zeros to approximate as a best effort attempt + if self._debug_level == dist.DebugLevel.INFO: + warnings.warn( + f"[Rank {self.rank}] Only some but not all ranks have a " + "`None` `FlatParameter` gradient, so FSDP is using zeros to " + "approximate those ranks' sharded gradients being `None`", + stacklevel=2, + ) + flat_param._saved_grad_shard = None # type: ignore[assignment] + sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined] + else: + self._check_sharded(flat_param.grad) + flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + padded_unsharded_grad = torch.empty( + flat_param._padded_unsharded_size, # type: ignore[attr-defined] + device=self.device, + dtype=sharded_grad.dtype, + ) + dist.all_gather_into_tensor( + padded_unsharded_grad, sharded_grad, self.process_group + ) + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view( + unsharded_size + ) + self._use_unsharded_grad_views() + + def reshard_grad(self): + if self._use_orig_params: + self._use_sharded_grad_views() + if not self.uses_sharded_strategy: + return + self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined] + delattr(self.flat_param, "_saved_grad_shard") + + def prepare_gradient_for_backward(self): + """ + Prepare the gradient for the backward computation. + + This is done by saving and clearing any existing sharded gradient + in ``.grad`` to enable computing a new unsharded gradient. + """ + _p_assert( + self._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), + "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", + ) + flat_param = self.flat_param + if flat_param.grad is not None and ( + flat_param.grad.size() != flat_param._unpadded_unsharded_size + or flat_param.grad.device != flat_param.device # grad on CPU + ): + self._check_on_compute_device(self.flat_param) + grad_offloaded = flat_param.grad.device != self.device + _p_assert( + not grad_offloaded or self._offload_params, + f"Expects the sharded gradient to be on {self.device} " + f"but got {flat_param.grad.device}", + ) + prev_iter_synced_gradients = ( + flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined] + ) + if prev_iter_synced_gradients: + # TODO (awgu): Gradient accumulation outside `no_sync()` + # does not work with CPU offloading. The issue should be + # that, in the post-backward hook, we cannot do an addition + # between a CPU tensor (the existing sharded gradient) and + # a GPU tensor (the new sharded gradient). + if not grad_offloaded: + flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + _p_assert( + hasattr(flat_param, "_cpu_grad"), + "`_cpu_grad` should be defined if the gradient is on CPU", + ) + sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined] + # If user specified to keep the gradient in low precision, then + # the gradient may still be of the low precision dtype if the + # user did not set the gradient to `None` after the previous + # backward, in which case FSDP should cast back to the full + # precision dtype so that FSDP can accumulate in that dtype in + # the post-backward hook and assign to `.grad` in that dtype in + # the post-backward callback. + local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined] + if ( + self._keep_low_precision_grads + and sharded_grad.dtype != local_shard_dtype + ): + sharded_grad.data = sharded_grad.to(local_shard_dtype) + else: + padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] + _p_assert( + flat_param.grad.size() == padded_unsharded_size, + "Expects `.grad` to be the unsharded gradient in " + f"`no_sync()` with size {padded_unsharded_size} " + f"but got size {flat_param.grad.size()}", + ) + flat_param.grad = None + + def prepare_gradient_for_optim(self): + """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute.""" + + def cast_grad_to_param_dtype_if_needed(flat_param): + # TODO (rohan-varma): test for full precision with keep_low_precision_grads + if not self._force_full_precision and self._keep_low_precision_grads: + _p_assert(flat_param.grad is not None, "Unexpected None grad!") + if flat_param.grad.dtype != self._fwd_bwd_param_dtype: + flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype) + if self._use_orig_params: + self._use_sharded_grad_views() + + flat_param = self.flat_param + # TODO (awgu): We should replace these conditional checks to encode + # the logical intention more directly. + if hasattr(flat_param, "_cpu_grad"): + # NOTE: This branch includes `NO_SHARD`. + self._check_sharded(flat_param) + self._check_on_cpu(flat_param) + flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined] + cast_grad_to_param_dtype_if_needed(flat_param) + elif hasattr(flat_param, "_saved_grad_shard"): + self._check_sharded(flat_param) + self._check_on_compute_device(flat_param) + if flat_param._saved_grad_shard is not None: + self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined] + # If no sharded gradient was computed this iteration, then there is + # no need to forward `_saved_grad_shard` to `grad` + if flat_param._post_backward_called: # type: ignore[attr-defined] + flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + if flat_param.grad is not None: + cast_grad_to_param_dtype_if_needed(flat_param) + else: + _p_assert( + not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined] + "All sharded parameters that received a gradient in the " + "post-backward should use `_saved_grad_shard`", + ) + # Delete `_saved_grad_shard` since its existence indicates a previous + # gradient to accumulate with in the post-backward hook + if hasattr(flat_param, "_saved_grad_shard"): + delattr(flat_param, "_saved_grad_shard") + + @contextlib.contextmanager + def to_cpu(self): + """ + Move the unpadded unsharded flat parameter to CPU while in the context and moves it back to the previous device upon exit. + + For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter + since (1) there is no reason to include the padding in the copy and (2) + there is no use case for the sharded flat parameter. + + Precondition: ``self.flat_param`` 's data is the unpadded unsharded + flat parameter on the compute device, and the handle uses a sharded + strategy. + Postcondition: Same as the precondition. + """ + self._check_sharded_strategy() + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + self._check_on_compute_device(self.flat_param) + # Check that the unpadded unsharded flat parameter is a view into the + # padded unsharded flat parameter as expected + # NOTE: This check is not strictly needed for correctness but is a + # useful sanity check since the tensor should only be used internally. + _p_assert( + _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()), + "Expects the unpadded parameter to be a view into the padded parameter", + ) + self.flat_param_to(torch.device("cpu")) + self._free_unsharded_flat_param() + try: + yield + finally: + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + # Copy from CPU to the compute device + padded_unsharded_flat_param[: self.flat_param.numel()].copy_( + self.flat_param + ) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def reshard(self, free_unsharded_flat_param: bool): + """ + Run the reshard logic. + + This includes freeing the unsharded flat + parameter if ``free_unsharded_flat_param`` and switching to using the + sharded flat parameter. Note that this also implicitly offloads + the sharded flat parameter (if CPU offload is enabled) by pointing + it to the ``_local_shard`` attribute which resides on CPU. + """ + # Switch to the sharded `FlatParameter` before freeing to prevent + # "use-after-free"-type bugs with external profiling tools, where for + # `use_orig_params=True`, the `param` does not point to valid memory + # when setting `param.data = ...` in `_use_sharded_views()`. + self._use_sharded_flat_param() + if free_unsharded_flat_param: + self._free_unsharded_flat_param() + + def post_reshard(self): + """ + Run the post-reshard logic. + + This includes freeing any memory that + can now be freed given that the ``FlatParameter`` points to the full + precision sharded flat parameter. + + Precondition: ``self.flat_param`` 's data points to the full precision + sharded flat parameter. + """ + # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it + # is also the low precision *unsharded* flat parameter. Hence, we delay + # the free until the reshard. + if ( + self._uses_param_mixed_precision + and not self.uses_sharded_strategy + and not self._force_full_precision # did not use the low precision shard + ): + self._free_low_precision_sharded_param() + + def _free_unsharded_flat_param(self): + """ + Free the padded unsharded flat parameter. We allow this + function to be called even when storage is not allocated + + The tensor to free depends + on the calling context since the unshard may have forced full + precision, in which case a different tensor is used. + """ + self._check_sharded_strategy() + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_on_compute_device(unsharded_flat_param) + # Do not free the memory until all ops in the current stream finish + _no_dispatch_record_stream( + unsharded_flat_param, self._device_handle.current_stream() + ) + _free_storage(unsharded_flat_param) + + def _use_sharded_flat_param(self) -> None: + """Switches to using the sharded flat parameter.""" + flat_param = self.flat_param + if self._use_orig_params: + in_forward = self._training_state == HandleTrainingState.FORWARD + skip_use_sharded_views = ( + torch.is_grad_enabled() + and in_forward + and self._sharding_strategy + in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + # Only incur the extra `.data` call if needed + if skip_use_sharded_views: + unsharded_flat_param = flat_param.data + if self._offload_params: + device = flat_param._local_shard.device # type: ignore[attr-defined] + _p_assert( + device == torch.device("cpu"), + f"Expects the local shard to be on CPU but got {device}", + ) + flat_param.data = flat_param._local_shard # type: ignore[attr-defined] + if self._use_orig_params: + if skip_use_sharded_views: # type: ignore[possibly-undefined] + self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined] + else: + self._use_sharded_views() + # For the post-forward reshard, we may try to use sharded gradient + # views (or unsharded gradient views if a gradient was accumulated + # in `no_sync()`), but for the post-backward reshard, we delay the + # call to after the reduce-scatter. + if ( + in_forward # type: ignore[possibly-undefined] + # Skip using gradient views if skipped using sharded views + # since exposing unsharded parameters with sharded gradients + # may be confusing to the user + and not self._skipped_use_sharded_views + ): + # TODO: Change `_unpadded_unsharded_size` if we change the + # gradient to be computed directly with padding. + accumulated_grad_in_no_sync = ( + flat_param.grad is not None + and self.uses_sharded_strategy + and flat_param.grad.shape == flat_param._unpadded_unsharded_size + ) + if accumulated_grad_in_no_sync: + self._use_unsharded_grad_views() + else: + self._use_sharded_grad_views() + + ######### + # VIEWS # + ######### + @no_type_check + def _get_unflat_views_unaligned( + self, + tensor: Optional[torch.Tensor] = None, + ) -> Iterator[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor``. + + If `tensor`` is ``None``, ``flat_param`` is used. The unflattening is based + on ``flat_param`` 's metadata. + + Examples for ``tensor`` include ``flat_param.grad`` or unsharded + tensor optimizer state. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + views = ( + _ext_post_unflatten_transform( + subtensor.view(shape) + if contiguous + else subtensor.as_strided(shape, stride), + param_extension, + self._fsdp_extension, + ) + for (subtensor, shape, stride, contiguous, param_extension) in zip( + torch.split(tensor, flat_param._numels, dim=0), + flat_param._shapes, + flat_param._strides, + flat_param._contiguities, + flat_param._param_extensions, + ) + ) + return views + + @no_type_check + def _get_unflat_views_aligned( + self, + tensor: Optional[Tensor] = None, + ) -> list[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor`` with handling for padding. + + This method has the same contract as :meth:`_get_unflat_views_unaligned` + except it checks for ``None`` placeholders representing padding for + alignment, which may incur slightly more CPU overhead. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + splits: list[Tensor] = torch.split( + tensor, flat_param._numels_with_padding, dim=0 + ) + idx = 0 + views: list[Tensor] = [] + for split, is_padding in zip(splits, flat_param._is_padding_mask): + if is_padding: + continue + views.append( + _ext_post_unflatten_transform( + split.view(flat_param._shapes[idx]) + if flat_param._contiguities[idx] + else split.as_strided( + flat_param._shapes[idx], flat_param._strides[idx] + ), + flat_param._param_extensions[idx], + self._fsdp_extension, + ) + ) + idx += 1 + return views + + @no_type_check + @torch.enable_grad() + def _use_unsharded_views(self, as_params: bool) -> None: + """ + Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it. + + Args: + as_params (bool): If ``True``, then registers the original + parameters as ``nn.Parameter`` s; if ``False``, then registers + the original parameters only as ``Tensor`` s. ``False`` should + be used during forward/backward computation and when hiding the + original parameters from :meth:`nn.Module.named_parameters`. + + Note: + when prefetching for next forward, current forward may be + annotated with `@torch.no_grad()` + `@torch.enable_grad()` ensures non-empty `view.grad_fn` + otherwise `_post_backward_hook` will not get called + """ + flat_param = self.flat_param + self._check_unsharded(flat_param) + views = self._get_unflat_views() + from torch.distributed.tensor import DTensor + + for i, (view, (param_name, module, _)) in enumerate( + zip(views, flat_param._param_infos) + ): + if self._use_orig_params and as_params: + if type(view) is DTensor: + # A `DTensor` `view` is not compatible with assigning + # `param.data = view`, so we cannot preserve the parameter + # variable. + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + continue + param = self.flat_param._params[i] + self._setattr_param(module, param_name, param) + param.data = view + elif as_params: + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + else: # `as_params=False` + param_var: Tensor = view + if self._use_orig_params: + if self._training_state == HandleTrainingState.FORWARD: + # Save the `Tensor` for the pre-backward + self.flat_param._tensors[i] = view # save for pre-backward + elif self._training_state == HandleTrainingState.BACKWARD_PRE: + # Use the saved `Tensor` variable from the forward to + # preserve the autograd graph so that the post-backward + # hook fires (e.g. for reentrant AC) + tensor = self.flat_param._tensors[i] + tensor.data = view + param_var = tensor + self._setattr_tensor(module, param_name, param_var) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = param_var + for i, ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in enumerate(self.flat_param._shared_param_infos): + prim_param: Union[Tensor, nn.Parameter] = getattr( + prim_module, prim_param_name + ) + _p_assert( + not as_params or isinstance(prim_param, nn.Parameter), + f"as_params={as_params} type(prim_param)={type(prim_param)}", + ) + if self._use_orig_params and as_params: + shared_param = self.flat_param._shared_params[i] + self._setattr_param(module, param_name, shared_param) + shared_param.data = prim_param + elif as_params: + self._setattr_param(module, param_name, prim_param) + else: + self._setattr_tensor(module, param_name, prim_param) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = prim_param + + @no_type_check + def _use_unsharded_grad_views(self) -> None: + """ + Unflatten the unsharded flat parameter's gradient. + + The original parameter variables' gradients are set to be views into + the unsharded flat parameter's gradient. + """ + # Expects the gradient to be in `flat_param.grad` + if self.flat_param.grad is None: + for param in chain(self.flat_param._params, self.flat_param._shared_params): + param.grad = None + return + self._check_unsharded(self.flat_param.grad) + views = self._get_unflat_views(self.flat_param.grad) + for i, (view, (param_name, module, _)) in enumerate( + zip(views, self.flat_param._param_infos) + ): + _p_assert( + hasattr(module, param_name), + f"{self.flat_param._fqns[i]} is missing", + ) + param = getattr(module, param_name) + if ( + param.shape != view.shape + or param.dtype != view.dtype + or param.device != view.device + ): + # NOTE: This is a hack using `.data` to side step the check + # that parameter/gradient sizes/dtypes/devices match. From + # calling `reshard()`, `param` has the sharded size, has the + # full precision dtype, and if CPU offloading is enabled, is on + # CPU. Thus, one or more of the following cases can hold when + # in `no_sync()`, where `view` is the original parameter's + # gradient: + # 1. `view` can have the unsharded size. + # 2. `view` can have the parameter low precision dtype. + # 3. `view` can be on GPU. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = view + else: + param.grad = view + for ( + param_name, + module, + module_name, + prim_param_name, + prim_module, + _, + ) in self.flat_param._shared_param_infos: + _p_assert( + hasattr(module, param_name), + f"{module_name + '.' + param_name if module_name else param_name} is missing", + ) + param = getattr(module, param_name) + prim_param = getattr(prim_module, prim_param_name) + if ( + param.shape != prim_param.grad.shape + or param.dtype != prim_param.grad.dtype + or param.device != prim_param.grad.device + ): + # NOTE: This is the same hack to use `.data` to side step the + # size check. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = prim_param.grad + else: + param.grad = prim_param.grad + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Unflatten the original parameters. + + The function assumes that the flat parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flat parameter, and after the context, restores the original parameters + as ``Tensor`` views into the flat parameter. + """ + self._use_unsharded_views(as_params=True) + try: + yield + finally: + self._use_unsharded_views(as_params=False) + + @no_type_check + @torch.no_grad() + def _use_sharded_views(self) -> None: + """ + Set the original parameter variables' data to be flattened views into the sharded flat parameter. + + The views are kept as flattened to simplify the case where a parameter + is sharded across ranks. Parameters whose data is not present in the + sharded flat parameter have their data set to a size-0 empty tensor. We + do not delete them to ensure to preserve expected behaviors like model + printability. Parameters whose data is present must preserve their + variables to be passable to an optimizer. + """ + self._unsharded_flat_param_for_skipped_views = None + if not self.uses_sharded_strategy: + # For `NO_SHARD`, use the *unflattened* unsharded views since we + # have the unsharded parameter + self._use_unsharded_views(as_params=True) + return + flat_param = self.flat_param + self._check_sharded(flat_param) + # Construct once and reuse for all parameters not in the local shard + size_0_empty_tensor = torch.empty( + 0, + dtype=self.flat_param.dtype, # in case `flat_param` changed dtype + device=self.flat_param.device, + requires_grad=False, + ) + for param, shard_param_info, (param_name, module, _) in zip( + flat_param._params, flat_param._shard_param_infos, flat_param._param_infos + ): + self._setattr_param(module, param_name, param) + if not shard_param_info.in_shard: + # Allow the original data to be freed via garbage collection + param.data = size_0_empty_tensor + else: + offset = shard_param_info.offset_in_shard + numel_in_shard = shard_param_info.numel_in_shard + param.data = flat_param[offset : offset + numel_in_shard] + if self.flat_param._shared_params is None: + raise AssertionError("Expected _shared_params to be not None") + for param, (param_name, module, _, prim_param_name, prim_module, _) in zip( + self.flat_param._shared_params, self.flat_param._shared_param_infos + ): + self._setattr_param(module, param_name, param) + prim_param = getattr(prim_module, prim_param_name) + param.data = prim_param # could be both empty and non-empty + if self._training_state == HandleTrainingState.BACKWARD_POST: + # Clear the saved `Tensor`s since they are unneeded now + for i in range(len(self.flat_param._tensors)): + self.flat_param._tensors[i] = None + + @no_type_check + @torch.no_grad() + def _use_sharded_grad_views(self) -> None: + """ + Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient. + + This is a no-op if there is no gradient. + + Parameters whose data is not present in the sharded flat parameter and + parameters with ``requires_grad=False`` have their gradients set to + ``None``. Since the gradient variables do not need to be preserved, + this method does not manipulate existing ``Tensor`` data directly and + creates new ``Tensor`` variables instead. + """ + flat_param = self.flat_param + self._check_sharded(flat_param) + grad = self.sharded_grad + if grad is None: + for param in chain(flat_param._params, flat_param._shared_params): + param.grad = None + return + self._check_sharded(grad) + for param, shard_param_info, is_grad_none in zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._is_grad_none_mask, + ): + if not shard_param_info.in_shard: + param.grad = None + else: + numel_in_shard = shard_param_info.numel_in_shard + if param.requires_grad and not is_grad_none: + offset = shard_param_info.offset_in_shard + if self._keep_low_precision_grads or param.dtype != grad.dtype: + # NOTE: This is a hack using `.data` to side step the + # check that parameter/gradient dtypes match. Here, + # `param` has full precision; `grad` has low precision. + if param.grad is None: + # `.grad` must have the same shape as `param` + param.grad = torch.empty_like(param) + param.grad.data = grad[ + offset : offset + numel_in_shard + ].reshape(param.shape) + else: + param.grad = grad[offset : offset + numel_in_shard].reshape( + param.shape + ) + else: + param.grad = None + if flat_param._shared_params is None: + raise AssertionError("Expected _shared_params to be not None") + for param, (_, _, _, prim_param_name, prim_module, _) in zip( + flat_param._shared_params, flat_param._shared_param_infos + ): + in_sharded_flat_param = hasattr(prim_module, prim_param_name) + if in_sharded_flat_param and param.requires_grad: + prim_param = getattr(prim_module, prim_param_name) + param.grad = prim_param.grad # share the same reference + else: + param.grad = None + + @no_type_check + @torch.no_grad() + def _writeback_orig_params(self) -> bool: + """ + Write back any parameters that changed storage to the handle's ``FlatParameter``. + + Iterates over the original parameters and writes back any parameters + that changed storages (due to a non-inplace operator) to the handle's + ``FlatParameter``. This method preserves the ``FlatParameter` 's + device even if an original parameter's device changes. + + Raises: + RuntimeError: If an original parameter or gradient changes storages + but no longer has the expected flattened shape. + Returns: ``True`` if some writeback happened, and ``False`` otherwise. + """ + if ( + self.uses_sharded_strategy + and not self.is_sharded(self.flat_param) + and not self._skipped_use_sharded_views + ): + # For `NO_SHARD`, we may still need to writeback + return False + flat_param = self.flat_param + wroteback = False + if self._skipped_use_sharded_views and self.uses_sharded_strategy: + # NOTE: We must use the unsharded flat parameter from which the + # unsharded views were computed, not the one from the current + # calling context (`_get_padded_unsharded_flat_param()`) since that + # may be different (e.g. the model changed from train to eval). + flat_param_tensor = self._unsharded_flat_param_for_skipped_views + _p_assert( + _data_ptr_allocated(flat_param_tensor), + "If skipped using sharded views, the unsharded flat parameter " + "should be allocated", + ) + else: + flat_param_tensor = flat_param + # NOTE: Since this method is called in the pre-unshard, which is only + # called during computation in the pre-forward or pre-backward, the + # sharded gradient should be guaranteed to be in `.grad`, not in + # `._saved_grad_shard`. + flat_param_grad = ( + flat_param.grad + if self.uses_sharded_strategy or not self._offload_params + else flat_param._cpu_grad + ) + for i, ( + param, + (in_shard, offset_in_shard, numel_in_shard, _, _), + (param_name, module, _), + ) in enumerate( + zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._param_infos, + ) + ): + if not in_shard: + continue + if not hasattr(module, param_name): + # Do not writeback if original parameters are deregistered + # (e.g. during model checkpointing) + continue + + # Check for parameter writeback + if self._skipped_use_sharded_views: + param = flat_param._tensors[i] + _p_assert( + param is not None, + f"Expects to have saved tensor for {flat_param._fqns[i]}", + ) + param_changed = getattr(module, param_name) is not param + needs_param_writeback = ( + param_changed # changed parameter variable itself + or not _same_storage(param, flat_param_tensor) + ) + if self._skipped_use_sharded_views and ( + param_changed or needs_param_writeback + ): + raise AssertionError( + "FSDP does not support changing the parameters between " + f"forward and backward for {self._sharding_strategy}" + ) + if param_changed: + # NOTE: The gradient is not preserved after a parameter change. + param = getattr(module, param_name) + flat_param._params[i] = param + if needs_param_writeback: + expected_shape = torch.Size([numel_in_shard]) + src = param if self.uses_sharded_strategy else param.view(-1) + self._writeback_tensor( + src, flat_param, i, expected_shape, offset_in_shard, True + ) + wroteback = True + + # Check for gradient writeback + if self._skipped_use_sharded_views: + # Skip the writeback check because we do not expose gradients + # when we skipped using sharded views + continue + if param.grad is None and flat_param.grad is not None: + expected_shape = torch.Size([numel_in_shard]) + self._writeback_tensor( + None, flat_param.grad, i, expected_shape, offset_in_shard, False + ) + elif param.grad is not None: + # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in + # memory and owns the gradient storage, so it will never + # require gradient writeback. + if not self.uses_sharded_strategy and self._offload_params: + # Explicitly continue to handle the case of `no_sync()`, + # where `param.grad` is a view into the GPU gradient + # referenced by `flat_param.grad`, while `flat_param_grad` + # is `flat_param._cpu_grad`, which is on CPU + continue + + needs_grad_writeback = flat_param_grad is None or not _same_storage( + param.grad, flat_param_grad + ) + if needs_grad_writeback: + if flat_param_grad is None: + flat_param_grad = torch.zeros_like(flat_param) + expected_shape = torch.Size([numel_in_shard]) + src = ( + param.grad + if self.uses_sharded_strategy + else param.grad.view(-1) + ) + self._writeback_tensor( + src, + flat_param_grad, + i, + expected_shape, + offset_in_shard, + False, + ) + flat_param.grad = flat_param_grad + flat_param_grad = flat_param.grad + + # TODO: If we want to handle shared parameters, we need to re-generate + # the shared parameter data structures in case sharedness changed. + for ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in flat_param._shared_param_infos: + if getattr(module, param_name) is not getattr(prim_module, prim_param_name): + raise NotImplementedError( + "Changing shared parameters is not supported yet" + ) + return wroteback + + def _writeback_tensor( + self, + src_tensor: Optional[Tensor], + dst_tensor: Tensor, + tensor_index: int, + expected_shape: torch.Size, + offset: int, + is_param: bool, # else gradient + ) -> None: + """ + Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``. + + ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if + ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing + instead of copying. ``tensor_index`` gives the index of ``src_tensor`` + in the metadata structures. + + Raises: + RuntimeError: If the ``src_tensor`` does not have the expected + shape. + """ + _p_assert( + len(expected_shape) == 1, + f"Expects a 1D expected shape but got {expected_shape}", + ) + if self._debug_level == dist.DebugLevel.INFO: + rank = self.rank if hasattr(self, "rank") else dist.get_rank() + src_shape = src_tensor.shape if src_tensor is not None else None + src_device = src_tensor.device if src_tensor is not None else None + warnings.warn( + f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs " + f"writeback in {self._training_state}\n" + f"expected shape={expected_shape} shape={src_shape} " + f"expected device={dst_tensor.device} device={src_device}", + stacklevel=2, + ) + if src_tensor is not None and src_tensor.shape != expected_shape: + # NOTE: Gradient shape mismatch is not possible in practice since + # the gradient shape is enforced to match that of the parameter and + # we already check for parameter shape mismatch. + raise RuntimeError( + f"Cannot writeback when the {'parameter' if is_param else 'gradient'} " + f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}" + ) + if src_tensor is not None: + dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor) + else: + dst_tensor[offset : offset + expected_shape.numel()].zero_() + if self.flat_param._is_grad_none_mask is None: + raise AssertionError("Expected _is_grad_none_mask to be not None") + self.flat_param._is_grad_none_mask[tensor_index] = True + + def _reset_flat_param_grad_info_if_needed(self): + """ + Reset ``flat_param.grad`` if needed. + + When ``use_orig_params=True``: + (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the + original parameters' ``.grad`` are ``None``, and + (2) sets ``flat_param.requires_grad=False`` if *none* of the original + parameters require gradient. + For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in + which case we want to free the gradients as soon after the + ``zero_grad()`` call as possible. + """ + if not self._use_orig_params: + return + flat_param = self.flat_param + if flat_param._params is None: + raise AssertionError("Expected _params to be not None") # mypy + all_grad_none = True + requires_grad = False + for param in flat_param._params: + all_grad_none &= param.grad is None + requires_grad |= param.requires_grad + if all_grad_none: + flat_param.grad = None + # As long as one parameter requires gradient, then the flat parameter + # must require gradient + flat_param.requires_grad = requires_grad + + def _deregister_orig_params(self): + for param_info in self.flat_param._param_infos: + param_name, module, _ = param_info + if hasattr(module, param_name): + delattr(module, param_name) + for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos: + if hasattr(module, param_name): + delattr(module, param_name) + + ########### + # HELPERS # + ########### + def flat_param_to(self, *args, **kwargs): + """Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" + # pyrefly: ignore [not-iterable] + self.flat_param.data = self.flat_param.to(*args, **kwargs) + if self._use_orig_params: + # Refresh the views because their storage may have changed + if self.is_sharded(self.flat_param): + self._use_sharded_views() + else: + self._use_unsharded_views(as_params=True) + + def _get_modules(self) -> set[nn.Module]: + """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter.""" + return {pi.module for pi in self.flat_param._param_infos}.union( + {spi.module for spi in self.flat_param._shared_param_infos} + ) + + def is_sharded(self, tensor: Tensor) -> bool: + """ + Return whether ``tensor`` is *currently* sharded. + + For ``NO_SHARD``, we choose to have this always return ``False`` for clarity. + """ + if ( + not hasattr(self.flat_param, "_sharded_size") + or not self.uses_sharded_strategy + ): + # `_sharded_size` is defined iff `handle.shard()` has been called + return False + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + return tensor.size() == sharded_size + + def param_module_names(self) -> Iterator[tuple[str, str]]: + shared_param_infos = [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ] + for param_info in chain(self.flat_param._param_infos, shared_param_infos): + param_name, _, module_name = param_info # type: ignore[misc] + yield (param_name, module_name) + + def shared_param_module_names(self) -> Iterator[tuple[str, str]]: + for param_name, _, module_name in [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ]: + yield (param_name, module_name) + + @property + def _fqns_in_shard(self) -> list[str]: + """Return the FQNs of the parameters present in this rank's shard.""" + fqns_in_shard: list[str] = [] + for fqn, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shard_param_infos, # type: ignore[attr-defined] + ): + if shard_param_info.in_shard: + fqns_in_shard.append(fqn) + return fqns_in_shard + + @property + def sharded_grad(self) -> Optional[Tensor]: + """Return the handle's sharded gradient.""" + flat_param = self.flat_param + # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad` + # - CPU offloading: `_cpu_grad` + # - No CPU offloading + sharded strategies: `_saved_grad_shard` + # - No CPU offloading + `NO_SHARD`: `grad` + grad: Optional[Tensor] + if hasattr(flat_param, "_cpu_grad"): + grad = flat_param._cpu_grad # type: ignore[attr-defined] + elif hasattr(flat_param, "_saved_grad_shard"): + # In the post-backward hook, the sharded gradient is still in + # `_saved_grad_shard`. + grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + # If in IDLE or in FORWARD states, then there may be an + # (accumulated) gradient. If accessed in IDLE, then this should + # be due to re-registering the original parameters (e.g. in state + # dict load). + _p_assert( + flat_param.grad is None + or not self.uses_sharded_strategy + or self._training_state + in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE), + "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` " + "unless in IDLE or FORWARD", + ) + grad = flat_param.grad + return grad + + def _reset_is_grad_none(self) -> None: + """ + Reset ``_is_grad_none_mask`` as needed. + + This method should only be + called in the post-backward after gradient computation, in which case + if a parameter requires gradient, then it will surely receive a + gradient and we may reset its mask entry to ``False``. + """ + if not self._use_orig_params: + return + _p_assert( + self._training_state == HandleTrainingState.BACKWARD_POST, + "Expects to only be called in the post-backward after gradient computation", + ) + flat_param = self.flat_param + if flat_param._params is None: + raise AssertionError("Expected _params to be not None") # mypy + for i, param in enumerate(flat_param._params): # type: ignore[arg-type] + # As long as the parameter requires gradient, it should receive a + # meaningful gradient (even if the gradient happens to be zeros) + if param.requires_grad: + if flat_param._is_grad_none_mask is None: + raise AssertionError( + "Expected _is_grad_none_mask to be not None" + ) # mypy + flat_param._is_grad_none_mask[i] = False + + ####################### + # CHECKS & INVARIANTS # + ####################### + def _check_sharded_strategy(self): + _p_assert(self.uses_sharded_strategy, "Expects sharded strategy") + + def _check_on_compute_device(self, tensor: Tensor): + _p_assert( + tensor.device == self.device, + f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}", + ) + + def _check_on_cpu(self, tensor: Tensor): + _p_assert( + tensor.device == torch.device("cpu"), + f"Expects tensor to be on CPU but got {tensor.device}", + ) + + @staticmethod + def _check_storage_freed(tensor: Tensor): + # Compile does not resize during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + _p_assert( + _same_storage_size(tensor, 0), + "Expects storage to be freed but got storage with size > 0", + ) + + @staticmethod + def _check_storage_allocated(tensor: Tensor): + _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated") + + def _check_low_precision_shard(self): + _p_assert( + self._uses_param_mixed_precision, + "Not using low precision for parameters", + ) + _p_assert( + getattr(self.flat_param, "_mp_shard", None) is not None, + "Expects `_mp_shard` to exist", + ) + device = self.flat_param._mp_shard.device # type: ignore[attr-defined] + _p_assert( + device == self.device, + f"Expects the low precision shard to be on {self.device} but got {device}", + ) + + def _check_unsharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be unsharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + unsharded_size = self.flat_param._unpadded_unsharded_size + _p_assert( + tensor.size() == unsharded_size, + msg_prefix + f"with size {unsharded_size} but got {tensor.size()}", + ) + + def _check_sharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be sharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + _p_assert( + tensor.size() == sharded_size, + msg_prefix + f"with size {sharded_size} but got {tensor.size()}", + ) + + ############## + # PROPERTIES # + ############## + @property + def uses_sharded_strategy(self) -> bool: + return self._sharding_strategy != HandleShardingStrategy.NO_SHARD + + @property + def _uses_param_mixed_precision(self) -> bool: + return self._fwd_bwd_param_dtype != self._orig_param_dtype + + @property + def _uses_reduce_mixed_precision(self) -> bool: + return self._reduce_dtype != self._orig_param_dtype + + @property + def _force_full_precision(self) -> bool: + return ( + self._uses_param_mixed_precision or self._uses_reduce_mixed_precision + ) and ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + or + # Also disable mixed precision in model eval mode, if configured + (not self._fully_sharded_module.training and self._use_full_prec_in_eval) + ) + + @property + def _skipped_use_sharded_views(self) -> bool: + """ + This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``. + + This returns if this handle is + currently in a state where it has skipped using sharded views, in which + case it can restore view invariants via ``_use_sharded_views()``. + """ + return self._unsharded_flat_param_for_skipped_views is not None + + +# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks. +def _unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + module._parameters[param_name] = param + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, param) + + +def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None: + module._parameters.pop(param_name, None) + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, tensor) + + +def _safe_setattr_tensor_or_param( + module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter] +): + # Call `delattr()` and `setattr()` to go through `nn.Module` checks + if hasattr(module, param_name): + delattr(module, param_name) + setattr(module, param_name, tensor_or_param) + + +def _convert_to_params( + tensors: list[Union[torch.Tensor, nn.Parameter]], +) -> list[nn.Parameter]: + return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] + + +def _is_truly_contiguous(x: Tensor) -> bool: + # Special case: Pytorch thinks that 1x1 channels_last convolution weights are + # both contiguous and channels_last contiguous at the same time. + # CuDNN does not agree though and refuses to select faster kernels. + # It is the reason of having the extra check here. + return x.stride(-1) == 1 and x.is_contiguous() + + +def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor: + return ( + param_or_tensor.detach() + if isinstance(param_or_tensor, nn.Parameter) + else param_or_tensor + ) + + +def _get_aligned_numel(unsharded_dtype: torch.dtype): + # NOTE: This alignment constraint comes from TorchInductor. + ALIGNMENT = 16 # bytes + unsharded_dtype_size = _get_dtype_size(unsharded_dtype) + aligned_numel = ALIGNMENT // unsharded_dtype_size + return aligned_numel + + +@functools.lru_cache(8) +def _get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +def _construct_padding_tensor( + padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device +): + # NOTE: Set the padding value as a magic number for debuggability. The + # value itself should never be used in any user-facing computation. + return ( + torch.ones( + (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device + ) + * _FLAT_PARAM_PADDING_VALUE + ) + + +# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning +# message is passed in) +@functools.lru_cache(1) +def _warn_skip_writeback_check(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_all_gather(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_reduce(log: logging.Logger, warning: str): + logger.warning(warning) + + +def _same_storage(a, b): + # Params are DTensors in backward + # with SHARD_GRAD_OP + TP + from torch.distributed.tensor import DTensor + + if isinstance(a, DTensor): + a = a._local_tensor + if isinstance(b, DTensor): + b = b._local_tensor + return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() + + +def _same_storage_size(a: torch.Tensor, b: int): + return a.untyped_storage().size() // a.element_size() == b + + +def _storage_size_allocated(tensor: Tensor): + storage_size: int = tensor.untyped_storage().size() + return storage_size > 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..699274ba50f9a57f26120bd15f5c49b4679f0e9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fsdp_extensions.py @@ -0,0 +1,180 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed.fsdp._shard_utils import ( + _all_gather_dtensor, + _create_chunk_dtensor, + _create_chunk_sharded_tensor, +) +from torch.distributed.tensor import DeviceMesh, DTensor + + +class FSDPExtensions(ABC): + """ + This enables some customizable hooks to enable composability with tensor + parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to + set a custom :class:`FSDPExtensions` that implements the hooks. + """ + + @abstractmethod + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[Any]]: + """E.g. converting ``DistributedTensor`` to local tensor.""" + ... + + @abstractmethod + def post_unflatten_transform( + self, + tensor: torch.Tensor, + param_extension: Any, + ) -> torch.Tensor: + """E.g. converting local tensor to ``DistributedTensor``.""" + ... + + @abstractmethod + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """Shards a tensor to chunks and returns the local chunk.""" + ... + + @abstractmethod + def chunk_dtensor( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + """Shards a tensor/DTensor to DTensor and returns the local DTensor.""" + ... + + @abstractmethod + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, list[Shard]]: + """ + This is to be called before loading a *sharded* model state dict and + should return the tensor and list of shards from which to load data. + """ + ... + + @abstractmethod + def all_gather_dtensor( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + """ + This is to be called before loading a *sharded* DTensor state dict. + This gathers tensor in FSDP dimension and returns local tensor of + TP DTensor. + """ + ... + + +_extensions: Optional[FSDPExtensions] = None + + +def _set_fsdp_extensions(flattener: FSDPExtensions) -> None: + global _extensions + _extensions = flattener + + +def _ext_pre_flatten_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, Optional[Any]]: + if fsdp_extension is not None: + new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor) + if param_extension is not None: + return new_tensor, param_extension + return tensor, None + + +def _ext_post_unflatten_transform( + tensor: torch.Tensor, + param_extension: Any, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + if fsdp_extension is not None and param_extension is not None: + return fsdp_extension.post_unflatten_transform(tensor, param_extension) + return tensor + + +def _ext_chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_tensor_fn = ( + fsdp_extension.chunk_tensor + if fsdp_extension is not None + else _create_chunk_sharded_tensor + ) + return chunk_tensor_fn( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _ext_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_dtensor_fn = ( + fsdp_extension.chunk_dtensor + if fsdp_extension is not None + else _create_chunk_dtensor + ) + return chunk_dtensor_fn( + tensor, + rank, + device_mesh, + ) + + +def _ext_pre_load_state_dict_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, list[Shard]]: + if fsdp_extension is not None: + return fsdp_extension.pre_load_state_dict_transform(tensor) + + if type(tensor) is not ShardedTensor: + raise AssertionError(f"Expected ShardedTensor, got {type(tensor)}") + shards = tensor.local_shards() + return (tensor, shards) + + +def _ext_all_gather_dtensor( + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + all_gather_dtensor_fn = ( + fsdp_extension.all_gather_dtensor + if fsdp_extension is not None + else _all_gather_dtensor + ) + return all_gather_dtensor_fn(tensor, parent_mesh) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36bdc23e741c0bbee64d4c79e8b1b5e0c553263c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py @@ -0,0 +1,1206 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import os +import warnings +from collections.abc import Callable, Generator, Iterable, Iterator +from typing import Any, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._exec_order_utils as exec_order_utils +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file +import torch.nn as nn +from torch.distributed.algorithms._comm_hooks import default_hooks +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _FSDPState, + _get_module_fsdp_state, + _is_fsdp_flattened, + _named_parameters_with_duplicates, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + _FSDP_USE_FULL_PREC_IN_EVAL, + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, +) +from torch.distributed.fsdp._limiter_utils import _FreeEventQueue +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.wrap import _Policy +from torch.distributed.tensor.parallel.fsdp import DTensorExtensions +from torch.distributed.utils import _sync_params_and_buffers +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +_TORCHDISTX_AVAIL = True +try: + from torchdistx import deferred_init, fake # type: ignore[import] +except ImportError: + _TORCHDISTX_AVAIL = False + +PARAM_BROADCAST_BUCKET_SIZE = 250 * 1024 * 1024 +FSDP_SYNCED = "_fsdp_synced" +# Specification of process groups for hybrid sharding strategies. +HybridShardProcessGroupType = tuple[dist.ProcessGroup, dist.ProcessGroup] +# Overall specification of process group. +ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]] + + +# TODO (awgu): Refactor this later +SHARDING_STRATEGY_MAP = { + ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD, + ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2, +} +HYBRID_SHARDING_STRATEGIES = [ + ShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2, +] +NO_RESHARD_AFTER_FORWARD_STRATEGIES = ( + ShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +# NOTE: Since non-self attributes cannot be type annotated, several attributes +# on `state` are defined first as local variables before being assigned. + + +@no_type_check +def _init_process_group_state( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, +) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError( + "Cannot pass both process_group and device_mesh at the " + "same time. Please just pass only one of them." + ) + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f"Manual wrapping with {sharding_strategy} " + "requires explicit specification of process group or device_mesh." + ) + else: + state = _init_process_group_state_for_hybrid_shard( + state, process_group, device_mesh + ) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = ( + process_group if process_group is not None else _get_default_group() + ) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor( + data_parallel_world_size + ) + ) + state._gradient_postdivide_factor = ( + data_parallel_world_size / state._gradient_predivide_factor + ) + return state + + +@no_type_check +def _init_process_group_state_for_hybrid_shard( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, +) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. + state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError( + f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}" + ) + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( + default_group, state._device_handle.device_count() + ) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. + if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. + state.process_group, state._inter_node_pg = process_group + else: + raise ValueError( + "Expected process_group to be passed in as either None or " + f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" + ) + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state( + process_group=state._inter_node_pg, + ) + return state + + +@no_type_check +def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool: + return ( + isinstance(process_group, tuple) + and len(process_group) == 2 + and all(isinstance(pg, dist.ProcessGroup) for pg in process_group) + ) + + +@no_type_check +def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool: + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 + + +@no_type_check +def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup: + """ + Return a process group across the current node. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return an intra-node subgroup across + [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank. + For example, rank 3 would get [0, 1, ..., 7]. + """ + intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node) + return intra_node_subgroup + + +@no_type_check +def _init_inter_node_process_group( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> dist.ProcessGroup: + """ + Return an inter-node process group where each contained rank has the same local rank. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth + depending on the process's rank. For example, rank 1 would get [1, 9], rank 5 + would get [5, 13]. + """ + # the inter-node pg that is returned + inter_node_pg = None + sharding_backend = dist.get_backend(global_process_group) + world_size = dist.get_world_size(global_process_group) + # Assuming fully homogeneous setup + num_nodes = world_size // num_devices_per_node + my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node + for local_rank in range(num_devices_per_node): + ranks_for_inter_group = [ + local_rank + (i * num_devices_per_node) for i in range(num_nodes) + ] + # every rank always needs to call dist.new_group + grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend) + if local_rank == my_local_rank: + inter_node_pg = grp + + if inter_node_pg is None: + raise AssertionError( + f"{my_local_rank} expected to assign inter-node pg, but did not" + ) + return inter_node_pg + + +def _init_intra_and_inter_node_groups( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> tuple[dist.ProcessGroup, dist.ProcessGroup]: + """ + Initialize intra and inter-node process groups and return the ones corresponding to this process's rank. + + This function can be used to initialize process groups for ``HYBRID_SHARD`` or + ``_HYBRID_SHARD_ZERO2`` in FSDP. + This function assumes each node has an equal number of CUDA-enabled devices. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group. + """ + return ( + _init_intra_node_process_group(num_devices_per_node), + _init_inter_node_process_group(global_process_group, num_devices_per_node), + ) + + +@no_type_check +def _init_ignored_module_states( + state: _FSDPState, + module: nn.Module, + ignored_modules: Optional[Iterable[torch.nn.Module]], + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, +) -> _FSDPState: + if ignored_modules is not None and ignored_states is not None: + raise ValueError( + "Cannot pass both ignored_modules and ignored_states at the " + "same time. Please just pass ignored_states." + ) + ignored_parameters = None + passed_as_ignored_states = ignored_states is not None + if passed_as_ignored_states: + ignored_states_list = list(ignored_states) + _check_ignored_states(ignored_states_list, True) + else: + ignored_states_list = [] + _check_ignored_states( + list(ignored_modules) if ignored_modules is not None else [], False + ) + if len(ignored_states_list) > 0: + if isinstance(ignored_states_list[0], nn.Parameter): + ignored_parameters = ignored_states_list + else: + ignored_modules = ignored_states_list + state._ignored_modules = _get_ignored_modules(module, ignored_modules) + state._ignored_params = _get_ignored_params( + module, + state._ignored_modules, + ignored_parameters, + ) + state._ignored_buffer_names = _get_ignored_buffer_names( + module, + state._ignored_modules, + ) + # TODO: FSDP's contract for buffers is not well-defined. They are + # implicitly ignored for most functionality since they are not sharded; + # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed + # precision). We should formalize this contract and decide if we need to + # compute and store `_ignored_buffers`. + return state + + +def _check_ignored_states( + ignored_states: list[Any], passed_as_ignored_states: bool +) -> None: + """ + Check that the ignored states are uniformly parameters or uniformly modules. + + We may remove this check in the future if we permit mixing. + """ + if len(ignored_states) == 0: + return + if passed_as_ignored_states: + all_params = all(isinstance(state, nn.Parameter) for state in ignored_states) + all_modules = all(isinstance(state, nn.Module) for state in ignored_states) + if not all_params and not all_modules: + # Sort for consistent ordering for unit test regex matching + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_states expects all nn.Parameter or all nn.Module list " + f"elements but got types {sorted_types}" + ) + else: + if not all(isinstance(state, nn.Module) for state in ignored_states): + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_modules expects nn.Module list elements but got " + f"types {sorted_types}" + ) + + +@no_type_check +def _init_device_handle( + state: _FSDPState, + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> _FSDPState: + """ + Determine device handle used for initializing FSDP. + + If a device is specified by ``device_id``, + then returns device handle corresponds to that device type. Otherwise, If the + module is already on a non-CPU device, then the device type is that non-CPU device type. + If the module is on CPU or meta, then the device type is the current accelerator device. + See the :ref:`Accelerators` for details. + + + This method will be called once ignored parameters was determined, as the device handle maybe needed + for other initialization. + """ + determined_device = None + if device_id is not None: + determined_device = ( + device_id + if isinstance(device_id, torch.device) + else torch.device(device_id) + ) + if determined_device is None: + for param in _get_orig_params(module, ignored_params): + if param.device.type in {"cpu", "meta"}: + continue + if determined_device is None: + determined_device = param.device + else: + if param.device.type != determined_device.type: + raise RuntimeError( + f"FSDP does not support modules with different device types " + f"but got params on {determined_device.type} and {param.device.type}" + ) + determined_device = determined_device or torch._C._get_accelerator() + if determined_device.type == "cpu": + raise RuntimeError( + "FSDP needs a non-CPU accelerator device, but no accelerator device is detected." + ) + + state._device_handle = _FSDPDeviceHandle.from_device(determined_device) + return state + + +@no_type_check +def _init_buffer_state( + state: _FSDPState, + module: nn.Module, +) -> _FSDPState: + state._buffer_names = _get_buffer_names(module) + # Save a mapping from clean fully-qualified buffer name (starting from + # `module`) to its original dtype for restoring that dtype during model + # checkpointing when buffer mixed precision is enabled. The names should + # be clean since the casting happens in a `summon_full_params()` context. + _buffer_name_to_orig_dtype: dict[str, torch.dtype] = {} + for buffer_name, buffer in module.named_buffers(): + buffer_name = clean_tensor_name(buffer_name) + _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype + state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype + return state + + +@no_type_check +def _init_core_state( + state: _FSDPState, + sharding_strategy: Optional[ShardingStrategy], + mixed_precision: Optional[MixedPrecision], + cpu_offload: Optional[CPUOffload], + limit_all_gathers: bool, + use_orig_params: bool, + backward_prefetch_limit: int, + forward_prefetch_limit: int, +) -> _FSDPState: + # We clamp the strategy to `NO_SHARD` for world size of 1 since they are + # currently functionally equivalent. This may change if/when we integrate + # FSDP with MoE. + if state.world_size == 1: + if sharding_strategy != ShardingStrategy.NO_SHARD: + warnings.warn( + "FSDP is switching to use `NO_SHARD` instead of " + f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since " + "the world size is 1.", + stacklevel=2, + ) + sharding_strategy = ShardingStrategy.NO_SHARD + elif sharding_strategy == ShardingStrategy.NO_SHARD: + warnings.warn( + "The `NO_SHARD` sharding strategy is deprecated. If having issues, " + "please use `DistributedDataParallel` instead.", + FutureWarning, + # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and + # level 3 is from the true caller + stacklevel=3, + ) + state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD + state.mixed_precision = mixed_precision or MixedPrecision() + if mixed_precision is not None: + torch._C._log_api_usage_once( + f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}" + ) + state._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + state.cpu_offload = cpu_offload or CPUOffload() + state.limit_all_gathers = limit_all_gathers + state._use_orig_params = use_orig_params + state.training_state = TrainingState.IDLE + state._is_root = None + state._free_event_queue = _FreeEventQueue() + state._debug_level = dist.get_debug_level() + state._exec_order_data = exec_order_utils._ExecOrderData( + state._debug_level, + backward_prefetch_limit, + forward_prefetch_limit, + ) + state._unshard_event = None + # Mapping from fully sharded module to the handles it is responsible to + # unshard and reshard (see [Note: Fully Sharded Module]) + _fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {} + state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle + # Invariant: `state.params` contains exactly the `FlatParameter`s of the + # handles in `state._handle` + _handle: Optional[FlatParamHandle] = None + state._handle = _handle + params: list[FlatParameter] = [] + state.params = params + return state + + +@no_type_check +def _init_runtime_state( + state: _FSDPState, +) -> _FSDPState: + _root_pre_forward_handles: list[RemovableHandle] = [] + state._root_pre_forward_handles = _root_pre_forward_handles + _pre_forward_handles: list[RemovableHandle] = [] + state._pre_forward_handles = _pre_forward_handles + _post_forward_handles: list[RemovableHandle] = [] + state._post_forward_handles = _post_forward_handles + state._sync_gradients = True + state._comm_hook = None + state._comm_hook_state = None + # Used to prevent running the pre-backward hook multiple times + return state + + +@no_type_check +def _init_prefetching_state( + state: _FSDPState, + backward_prefetch: BackwardPrefetch, + forward_prefetch: bool, +) -> _FSDPState: + state.backward_prefetch = backward_prefetch + state.forward_prefetch = forward_prefetch + # The data structures use tuples of handles to generalize over the case + # where a module's forward involves multiple handles. + return state + + +@no_type_check +# pyrefly: ignore [bad-function-definition] +def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: + # TODO: we need to add additional check once we support FSDP + PiPPy. + # This check is currently sufficient, since we only support FSDP + TP. + root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if device_mesh and root_mesh != state._device_mesh: + state._fsdp_extension = DTensorExtensions(state._device_handle) + else: + # We need to explicitly set _fsdp_extension to None. + # Otherwise, we will run into an infinite recursion when getting the attribute. + state._fsdp_extension = None + return state + + +@no_type_check +def _init_state_dict_state(state: _FSDPState) -> _FSDPState: + state._state_dict_type = StateDictType.FULL_STATE_DICT + state_dict_config: StateDictConfig = FullStateDictConfig() + state._optim_state_dict_config = FullOptimStateDictConfig() + state._state_dict_config = state_dict_config + unshard_params_ctx: dict[nn.Module, Generator] = {} + state._unshard_params_ctx = unshard_params_ctx + + return state + + +def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None: + """ + Verify if the parameters are accepted by FSDP. The only restriction now + is that the parameter cannot be a scalar tensor (param.shape == []). + """ + for param in params: + if len(param.shape) == 0: + param_name = "" + for name, param_ in module.named_parameters(): + if param is param_: + param_name = name + break + if not param_name: + raise AssertionError("Expected param_name to be set") + raise ValueError( + "FSDP doesn't support scalar parameters. " + f"Change {param_name} to a 1D tensor with numel equal to 1." + ) + + +@no_type_check +def _init_param_handle_from_module( + state: _FSDPState, + fully_sharded_module: nn.Module, + device_id: Optional[Union[int, torch.device]], + param_init_fn: Optional[Callable[[nn.Module], None]], + sync_module_states: bool, +) -> _FSDPState: + """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``.""" + _check_single_device_module(fully_sharded_module, state._ignored_params, device_id) + device_from_device_id = _get_device_from_device_id( + device_id, state.rank, state._device_handle + ) + is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module( + fully_sharded_module, state._ignored_params, state._ignored_modules + ) + # Materialize the module if needed + if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None: + _materialize_with_param_init_fn( + fully_sharded_module, param_init_fn, state._ignored_modules + ) + elif is_meta_module: + _materialize_meta_module( + fully_sharded_module, + device_id, + state._ignored_modules, + state._device_handle, + ) + elif is_torchdistX_deferred_init: + deferred_init.materialize_module( + fully_sharded_module, + check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None + and submodule not in state._ignored_modules, + ) + + ignored_buffers = { + buffer + for ignored_module in state._ignored_modules + for buffer in ignored_module.buffers() + } + + _move_module_to_device( + fully_sharded_module, + state._ignored_params, + ignored_buffers, + device_from_device_id, + ) + state.compute_device = _get_compute_device( + fully_sharded_module, + state._ignored_params, + device_from_device_id, + state.rank, + state._device_handle, + ) + + managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params)) + _verify_managed_params(fully_sharded_module, managed_params) + if sync_module_states: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state.process_group + ) + if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state._inter_node_pg + ) + _init_param_handle_from_params(state, managed_params, fully_sharded_module) + return state + + +@no_type_check +def _init_param_handle_from_params( + state: _FSDPState, + params: list[nn.Parameter], + fully_sharded_module: nn.Module, +): + if len(params) == 0: + return + handle = FlatParamHandle( + params, + fully_sharded_module, + state.compute_device, + SHARDING_STRATEGY_MAP[state.sharding_strategy], + state.cpu_offload.offload_params, + state.mixed_precision.param_dtype, + state.mixed_precision.reduce_dtype, + state.mixed_precision.keep_low_precision_grads, + state.process_group, + state._use_orig_params, + fsdp_extension=state._fsdp_extension, + ) + handle.shard() + if state._handle: + raise AssertionError("Expected state._handle to be None") + state.params.append(handle.flat_param) + state._handle = handle + state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle + cpu_device = torch.device("cpu") + if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device: + handle.flat_param_to(cpu_device) + + +def _get_ignored_modules( + root_module: nn.Module, + _ignored_modules: Optional[Iterable[torch.nn.Module]], +) -> set[nn.Module]: + """ + Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances. + + Return the modules contained in their module + subtrees as a :class:`set`. Nested FSDP instances are excluded, but their + already-computed ignored modules are included. + + ``_ignored_modules`` represents the argument passed by the user to FSDP. + """ + msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s " + try: + ignored_root_modules = ( + set(_ignored_modules) if _ignored_modules is not None else set() + ) + except TypeError as e: + raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e + for module in ignored_root_modules: + if not isinstance(module, torch.nn.Module): + raise TypeError(msg_prefix + f"but got an iterable with {type(module)}") + if _get_module_fsdp_state(module): + # TODO: We may relax this by taking the FSDP instance's wrapped + # module to provide more flexibility to the user. + raise ValueError("`ignored_modules` should not include FSDP modules") + # Treat modules that cannot compose with `fully_shard` as ignored modules, + # meaning that their subtrees are ignored + for module in root_module.modules(): + if not traversal_utils._composable(module): + ignored_root_modules.add(module) + # NOTE: Even if `ignored_root_modules` is empty, do not return early so + # that this FSDP instance can get any ignored modules from its children. + + # Include child modules and exclude nested FSDP modules themselves + ignored_modules = { + child + for module in ignored_root_modules + for child in module.modules() + if not isinstance(child, fsdp_file.FullyShardedDataParallel) + } + if root_module in ignored_modules: + warnings.warn( + "Trying to ignore the top-level module passed into the FSDP " + "constructor itself will result in all parameters being " + f"ignored and is not well-supported: {module}", + stacklevel=2, + ) + # Include nested FSDP modules' ignored modules + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_modules"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_modules attribute" + ) + ignored_modules.update(optional_fsdp_state._ignored_modules) + return ignored_modules + + +def _get_ignored_params( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], + ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None, +) -> set[torch.nn.Parameter]: + """ + Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``. + + :class:`FlatParameter` s are excluded from the result. + """ + all_ignored_params: set[torch.nn.Parameter] = set() + + params_in_ignored_modules = { + p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p) + } + + all_ignored_params.update(params_in_ignored_modules) + + if ignored_parameters is not None: + params_in_ignored_parameters = { + p for p in ignored_parameters if not _is_fsdp_flattened(p) + } + all_ignored_params.update(params_in_ignored_parameters) + + # Always include nested FSDP modules' ignored parameters + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_params"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_params attribute" + ) + all_ignored_params.update(optional_fsdp_state._ignored_params) + + return all_ignored_params + + +def _get_ignored_buffer_names( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], +) -> set[str]: + """Return the cleaned buffer FQNs in ``ignored_modules``.""" + all_ignored_buffer_names: set[str] = set() + + buffers_in_ignored_modules = { + buffer for m in ignored_modules for buffer in m.buffers() + } + + all_ignored_buffer_names.update( + { + clean_tensor_name(buffer_name) + for buffer_name, buffer in root_module.named_buffers() + if buffer in buffers_in_ignored_modules + } + ) + + # Always include nested FSDP modules' ignored buffer names + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + if not hasattr(optional_fsdp_state, "_ignored_buffer_names"): + raise AssertionError( + "Expected optional_fsdp_state to have _ignored_buffer_names attribute" + ) + all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names) + + return all_ignored_buffer_names + + +def _get_buffer_names(root_module: nn.Module) -> set[str]: + """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`.""" + return { + clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers() + } + + +def _check_single_device_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> None: + """ + Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``. + + Thus, after this method, the + module must be either fully on the CPU or fully on a non-CPU device. + """ + devices = {param.device for param in _get_orig_params(module, ignored_params)} + # We allow module to be partially on CPU and partially on GPU if device_id is not + # None, since the device_id arg will result in the CPU portion being moved to + # GPU. This is useful in cases where part of the module may be parallelized + # by another algorithm and may already be on GPU. We'd like to enforce device_id + # to not be None, otherwise we'd flatten parameters in a mixed module which is + # not supported. + if len(devices) == 2 and torch.device("cpu") in devices: + if device_id is None: + raise RuntimeError( + "To support a module with both CPU and GPU params, " + "please pass in device_id argument." + ) + elif len(devices) > 1: + raise RuntimeError( + f"FSDP only supports single device modules but got params on {devices}" + ) + + +def _get_device_from_device_id( + device_id: Optional[Union[int, torch.device]], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> Optional[torch.device]: + """ + Return a ``torch.device`` for the specified ``device_id``. + + Processes ``device_id`` and returns either the corresponding device or + ``None`` if ``device_id`` is ``None``. + """ + if device_id is None: + return None + device = ( + device_id if isinstance(device_id, torch.device) else torch.device(device_id) + ) + if device.type != "cpu" and device.index is None: + warnings.warn( + f"FSDP got the argument `device_id` {device_id} on rank " + f"{rank}, which does not have an explicit index. " + f"FSDP will use the current device {device_handle.current_device()}. " + f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` " + "before FSDP initialization or pass in the explicit device " + "index as the `device_id` argument.", + stacklevel=2, + ) + device = torch.device(device_handle.current_device()) + return device + + +def _need_to_materialize_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_modules: set[nn.Module], +) -> tuple[bool, bool]: + """ + Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization. + + At most of the returned bools can + be ``True``. If either is ``True``, then ``module`` needs to be + materialized. + """ + managed_params = list(_get_orig_params(module, ignored_params)) + is_meta_module = any(param.is_meta for param in managed_params) + # TODO: We need to establish a contract for FSDP and buffers. For now, we + # skip checking for meta buffers from ignored modules. We should consider + # refactoring the initialization holistically to avoid so many traversals. + for submodule in module.modules(): + if submodule in ignored_modules: + continue + for buf in submodule.buffers(recurse=False): + is_meta_module |= buf.is_meta + is_torchdistX_deferred_init = ( + not is_meta_module + and _TORCHDISTX_AVAIL + and any(fake.is_fake(param) for param in managed_params) + ) + return is_meta_module, is_torchdistX_deferred_init + + +def _materialize_with_param_init_fn( + root_module: nn.Module, + param_init_fn: Callable[[nn.Module], None], + ignored_modules: set[nn.Module], +) -> None: + if not callable(param_init_fn): + raise ValueError( + f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}" + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + for module in modules_to_materialize: + param_init_fn(module) + + +def _materialize_meta_module( + root_module: nn.Module, + device_from_device_id: Optional[torch.device], + ignored_modules: set[nn.Module], + device_handle: _FSDPDeviceHandle, +): + # Run default meta device initialization + materialization_device = device_from_device_id or torch.device( + device_handle.current_device() + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain( + module.parameters(recurse=False), + # pyrefly: ignore [bad-argument-type] + module.buffers(recurse=False), + ) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=materialization_device, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + warnings.warn( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method.", + stacklevel=2, # type: ignore[possibly-undefined] + ) + raise e + + +def _get_modules_to_materialize( + root_module: nn.Module, ignored_modules: set[nn.Module] +) -> list[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. + modules_to_materialize: list[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if ( + child_module not in visited_modules + and _get_module_fsdp_state(child_module) is None + and child_module not in ignored_modules + ): + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _move_module_to_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_buffers: set[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move ``module`` depending on ``device_from_device_id`` and its current device. + + This includes moving ignored modules' parameters. + + - If ``device_from_device_id`` is not ``None``, then this moves + ``module`` to the device. + - If ``device_from_device_id`` is ``None``, then this does not move + ``module`` but warns the user if it is on CPU. + + Precondition: ``_check_single_device_module()``. + """ + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # BFS from `module` without traversing any nested FSDP instances to + # collect the parameters/buffers that have not yet been managed + queue: collections.deque[nn.Module] = collections.deque() + queue.append(module) + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] + while queue: + curr_module = queue.popleft() + # NOTE: We include a check to only move parameters/buffers that are + # on CPU device. If they are on a CUDA device different from the + # one specified by `device_id`, then this does NOT move them. This + # is so that we can raise an error in `_get_compute_device()`. + params.extend( + param + for param in curr_module.parameters(recurse=False) + if param.device == cpu_device + ) + buffers.extend( + buffer + for buffer in curr_module.buffers(recurse=False) + if buffer.device == cpu_device + ) + for submodule in curr_module.children(): + if not isinstance(submodule, fsdp_file.FullyShardedDataParallel): + queue.append(submodule) + params_to_move = [p for p in params if p not in ignored_params] + bufs_to_move = [p for p in buffers if p not in ignored_buffers] + _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id) + return + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device == cpu_device: + _warn_cpu_init() + + +def _move_states_to_device( + params: list[nn.Parameter], + buffers: list[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move states to the specified device. + + Precondition: ``_check_single_device_module()`` and module's parameters and + buffers have been materialized if needed. + """ + if len(params) == 0 and len(buffers) == 0: + return + if len(params) > 0: + current_device = params[0].device + elif len(buffers) > 0: + current_device = buffers[0].device + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # Move the parameters and buffers like the `.data` code path in + # `nn.Module._apply()`, which underlies `nn.Module.to()` + for param in params: + with torch.no_grad(): + param.data = param.to(device_from_device_id) + if param.grad is not None: + param.grad.data = param.grad.to(device_from_device_id) + for buffer in buffers: + buffer.data = buffer.to(device_from_device_id) + elif current_device == cpu_device: # type: ignore[possibly-undefined] + _warn_cpu_init() + + +def _warn_cpu_init(): + warnings.warn( + "The passed-in `module` is on CPU and will thus have FSDP's sharding " + "initialization run on CPU, which may be slower than on GPU. We " + "recommend passing in the `device_id` argument for FSDP to move " + "`module` to GPU for the sharding initialization. `module` must also " + "be on GPU device to work with the `sync_module_states=True` flag " + "since that requires GPU communication.", + stacklevel=2, + ) + + +def _get_compute_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_from_device_id: Optional[torch.device], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> torch.device: + """ + Determine and return this FSDP instance's compute device. + + If the module is already on a non-CPU device, then the compute device is that non-CPU + device. If the module is on CPU, then the compute device is the current + device. + + Since this method should be called after materializing the module, any + non-CPU device should not be meta device. For now, the compute device is + always a CUDA or CUDA-like device with its explicit index. + + Precondition: ``_check_single_device_module()`` and + ``_move_module_to_device()``. + """ + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device.type != "cpu": + compute_device = param.device # Determined by model param placement + else: + compute_device = torch.device(device_handle.current_device()) + if device_from_device_id is not None and compute_device != device_from_device_id: + raise ValueError( + f"Inconsistent compute device and `device_id` on rank {rank}: " + f"{compute_device} vs {device_from_device_id}" + ) + return compute_device + + +# TODO: See how to deprecate! +def _sync_module_params_and_buffers( + module: nn.Module, + params: list[nn.Parameter], + process_group: dist.ProcessGroup, +) -> None: + """ + Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks. + + Precondition: ``sync_module_states == True`` and ``self.process_group`` has + been set. + """ + module_states: list[torch.Tensor] = [] + for buffer in module.buffers(): + # Avoid re-synchronizing buffers in case of nested wrapping + if not getattr(buffer, FSDP_SYNCED, False): + setattr(buffer, FSDP_SYNCED, True) + detached_buffer = buffer.detach() + if is_traceable_wrapper_subclass(detached_buffer): + # NOTE: Here we assume no nested subclasses, at most one level of subclass + # in both model's buffers and params + attrs, _ = detached_buffer.__tensor_flatten__() # type: ignore[attr-defined] + inner_buffers = [getattr(detached_buffer, attr) for attr in attrs] + module_states.extend(inner_buffers) + else: + module_states.append(detached_buffer) + + for param in params: + detached_param = param.detach() + if is_traceable_wrapper_subclass(detached_param): + attrs, _ = detached_param.__tensor_flatten__() # type: ignore[attr-defined] + inner_params = [getattr(detached_param, attr) for attr in attrs] + module_states.extend(inner_params) + else: + module_states.append(detached_param) + + _check_module_states_for_sync_module_states(module_states) + _sync_params_and_buffers( + process_group, + module_states, + PARAM_BROADCAST_BUCKET_SIZE, + src=0, + ) + + +def _check_module_states_for_sync_module_states( + module_states: list[torch.Tensor], +) -> None: + if module_states and any( + tensor.device == torch.device("cpu") for tensor in module_states + ): + raise ValueError( + "The module has CPU parameters or buffers when `sync_module_states=True`, " + "which requires them to be on GPU. Please specify the `device_id` argument " + "or move the module to GPU before passing it to FSDP." + ) + + +def _get_orig_params( + module: nn.Module, + ignored_params: set[nn.Parameter], +) -> Iterator[nn.Parameter]: + """ + Return an iterator over the original parameters in ``module``. + + The iterator does not return + the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be + present due to nested FSDP wrapping), or any original parameters already + flattened (only relevant when ``use_orig_params=True``). + """ + param_gen = module.parameters() + try: + while True: + param = next(param_gen) + if param not in ignored_params and not _is_fsdp_flattened(param): + yield param + except StopIteration: + pass + + +def _check_orig_params_flattened( + fsdp_module, + ignored_params: set[nn.Parameter], +) -> None: + """ + Check that original parameters in ``fsdp_module`` have been flattened. + + The flattened parameters are made + invisible to ``named_parameters()`` for the module hierarchy rooted at + ``fsdp_module``. This should be called as a sanity check after flattening + the wrapped module's parameters. + """ + for param_name, param in _named_parameters_with_duplicates(fsdp_module): + if param not in ignored_params and not _is_fsdp_flattened(param): + raise RuntimeError( + f"Found an unflattened parameter: {param_name}; " + f"{param.size()} {param.__class__}" + ) + + +def _get_default_comm_hook(sharding_strategy: ShardingStrategy): + return ( + default_hooks.allreduce_hook + if sharding_strategy == ShardingStrategy.NO_SHARD + else default_hooks.reduce_scatter_hook + ) + + +def _get_default_comm_hook_state( + process_group: dist.ProcessGroup, +) -> default_hooks.DefaultState: + return default_hooks.DefaultState(process_group=process_group) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b190585342ee267716abace19add022b4d6b3e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_limiter_utils.py @@ -0,0 +1,33 @@ +import collections +from typing import Optional + +import torch + + +class _FreeEventQueue: + """ + This tracks all pending frees corresponding to inflight all-gathers. The + queueing pattern is iterative enqueues with a single dequeue per iteration + once the limit ``_max_num_inflight_all_gathers`` is reached. + """ + + def __init__(self) -> None: + self._queue: collections.deque[torch.Event] = collections.deque() + self._max_num_inflight_all_gathers = 2 # empirically chosen + + def enqueue(self, free_event: torch.Event) -> None: + """Enqueues a free event.""" + self._queue.append(free_event) + + def dequeue_if_needed(self) -> Optional[torch.Event]: + """Dequeues a single event if the limit is reached.""" + if len(self._queue) >= self._max_num_inflight_all_gathers: + return self._dequeue() + return None + + def _dequeue(self) -> Optional[torch.Event]: + """Dequeues a free event if possible.""" + if self._queue: + event = self._queue.popleft() + return event + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..564cfeece48ee1e656ea4e06628c36c0d01c0af8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_optim_utils.py @@ -0,0 +1,2139 @@ +# mypy: allow-untyped-defs +import copy +import functools +import logging +import warnings +from collections.abc import Iterable, Iterator, Sequence +from contextlib import ExitStack +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed._state_dict_utils import _gather_state_dict +from torch.distributed.distributed_c10d import _get_pg_default_device +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _get_param_to_fqns, + _module_handle, + _named_parameters_with_duplicates, + clean_tensor_name, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle +from torch.distributed.fsdp._fsdp_extensions import ( + _ext_chunk_dtensor, + _ext_chunk_tensor, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + ShardingStrategy, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DTensor, Replicate +from torch.utils._pytree import tree_map_only + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +logger = logging.getLogger(__name__) + + +@dataclass +class FSDPParamInfo: + state: _FSDPState + handle: FlatParamHandle + param_indices: dict[str, int] + param_requires_grad: list[bool] + + +def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]: + keys = sorted(dictionary.keys()) + for k in keys: + yield k, dictionary[k] + + +@dataclass +class _ConsolidatedOptimState: + """ + This holds the consolidated optimizer state on the target rank. Positive- + dimension tensor state is communicated across ranks, while zero-dimension + tensor state and non-tensor state is taken directly from the target rank. + + PyTorch version 1.12 moved to using zero-dimension tensors for scalar + values, but user implemented optimizers may still use float (i.e. a + non-tensor). Thus, we support both and handle them identically. + + Attributes: + tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension + tensor state name to the unsharded flat tensor representing the + state. + zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero- + dimension tensor state name to its value. + non_tensor_state (Dict[str, Any]): Mapping from non-tensor state + name to its value. + """ + + tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + non_tensor_state: dict[str, Any] = field(default_factory=dict) + + +class _PosDimTensorInfo(NamedTuple): + """ + Metadata for positive-dimension tensors used internally for + :meth:`scatter_full_optim_state_dict`. + + Attributes: + shape (torch.Size): Sharded tensor shape (which is equal to the + unsharded tensor shape if the tensor is optimizer state for a + non-FSDP parameter and is hence not sharded). + dtype (torch.dtype): Data type of the tensor. + """ + + shape: torch.Size + dtype: torch.dtype + + +class _OptimStateKey(NamedTuple): + """ + This represents an optimizer state key that may be used commonly across + ranks. It is based on the unflattened parameter names rather than parameter + IDs to make it independent of each rank's own optimizer construction. + """ + + unflat_param_names: tuple[str, ...] + is_fsdp_managed: bool + + +def _unflatten_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the optimizer state, consisting of the "state" part and the + "param_groups" part. Unflattening the "state" part involves consolidating + the state on the target rank and remapping from flattened to unflattened + parameter IDs, and the "param_groups" part only involves remapping from + flattened to unflattened parameter IDs. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): Entry for the flat parameter in the + "state" part of the optimizer state dict. + to_save (bool): Whether to save the state on this rank. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter if on the target + rank or an empty :class:`list` otherwise. The final optimizer state + dict will need to map these entries using the proper unflattened + parameter IDs. + """ + if shard_state and not to_save: + raise AssertionError("If ``shard_state`` is True, ``to_save`` has to be True.") + consolidated_state = _communicate_optim_state( + fsdp_param_info, + flat_param_state, + ) + if to_save: + unflat_param_state = _unflatten_communicated_optim_state( + fsdp_param_info, + consolidated_state, + shard_state, + ) + for optim_state in unflat_param_state: + # We can't use .items() below cuz we'd run into a concurrent modification error + if cpu_offload: + for key in list(optim_state.keys()): + state = optim_state[key] + if not isinstance(state, torch.Tensor): + continue + optim_state[key] = state.cpu() + return unflat_param_state + else: + return [] + + +def _is_zero_dim_tensor(x: Any) -> bool: + return torch.is_tensor(x) and x.dim() == 0 + + +def _communicate_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], +) -> _ConsolidatedOptimState: + """ + Communicates the optimizer state for a flat parameter across ranks. All + ranks will hold the entire non-sharded optimizer state on GPU. + + If ``N`` is the number of tensor optimizer states in the optimizer state + dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` + otherwise (where the plus 1 comes from all-gathering the padding per rank). + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): The entry in the "state" part of the + optimizer state dict corresponding to the flat parameter. + + Returns: + ConsolidatedOptimState: Consolidated optimizer state for the target + flat parameter. + """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + state = _ConsolidatedOptimState() + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for state_name, value in sorted_items(flat_param_state): + # Positive-dimension tensor state: communicate across ranks + if torch.is_tensor(value) and value.dim() > 0: + # If the parameter is not sharded, then neither is the + # positive-dimension tensor state, so no need to communicate it -- + # we take the target rank's value + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + tensor_state[state_name] = value + continue + if fsdp_state.compute_device is None: + raise AssertionError("compute_device has not been initialized") + if value.device.type != fsdp_state.compute_device.type: + value = value.to(fsdp_state.compute_device) + # Assume that positive-dimension tensor optimizer state + # has the same shape as the sharded flat parameter + buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] + tensor_buffer = value.new_zeros(*buffer_size) + dist.all_gather_into_tensor( + tensor_buffer, value, group=fsdp_state.process_group + ) + fsdp_state._device_handle.synchronize() + unpadded_numel = cast( + nn.Parameter, flat_param._unpadded_unsharded_size + ).numel() + tensor_state[state_name] = tensor_buffer[:unpadded_numel] + # Zero-dimension tensor state and non-tensor state: take this rank's + # value directly + else: + if _is_zero_dim_tensor(value): + zero_dim_tensor_state[state_name] = value.detach().clone() + else: + non_tensor_state[state_name] = value + return state + + +def _unflatten_communicated_optim_state( + fsdp_param_info: FSDPParamInfo, + state: _ConsolidatedOptimState, + shard_state: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the communicated optimizer state (given by ``tensor_state``, + ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat + parameter. This should only be called on the target rank. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + state (_ConsolidatedOptimState): Consolidated optimizer state. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter. The final + optimizer state dict will need to map these entries using the proper + unflattened parameter IDs. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + unflat_param_state: list[dict[str, Any]] = [] + flat_param_views: dict[str, Iterator] = {} + num_unflat_params = flat_param._num_params + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for _ in range(num_unflat_params): + unflat_state_param = {} + # Add positive-dimension tensor state: unflatten with views + for state_name, flat_tensor in sorted_items(tensor_state): + views_generated = state_name in flat_param_views + if not views_generated: + views = handle._get_unflat_views(flat_tensor) + flat_param_views[state_name] = views + else: + views = flat_param_views[state_name] + optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views) + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + if fsdp_state._device_mesh is None: + raise AssertionError( + f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}" + ) + optim_state = _ext_chunk_dtensor( + optim_state, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + if fsdp_state.process_group is None: + raise AssertionError( + f"Expected process_group to be not None, got {fsdp_state.process_group}" + ) + optim_state = _ext_chunk_tensor( + optim_state, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + unflat_state_param[state_name] = optim_state + + # Add zero-dimension tensor state: take the target rank's value + unflat_state_param.update(sorted_items(zero_dim_tensor_state)) + # Add non-tensor state: take the target rank's value + unflat_state_param.update(sorted_items(non_tensor_state)) + unflat_param_state.append(unflat_state_param) + return unflat_param_state + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> dict[str, Any]: + objects: list[Any] = [None] + if dist.get_rank(group) == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank(group) == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state( + fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] +) -> Any: + if dist.get_rank(group) == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + if state.dim() != 0: + raise AssertionError( + "For non-zero ranks, a tensor state should have zero dimension, " + f"but got the state with shape {state.shape}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros( + state.shape, dtype=state.dtype, device=fsdp_state.compute_device + ) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _shard_orig_param_state( + fsdp_param_info: FSDPParamInfo, + fqn: str, + optim_state: dict[str, Any], +) -> dict[str, Any]: + """ + Shard the optimizer state for the original parameter with the name ``fqn``. + This API should only be used when ``use_orig_params`` is True. + """ + if not optim_state: + return {} + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + param_idx = fsdp_param_info.param_indices[fqn] + shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] + optim_state = _gather_state_dict( + optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device + ) + if not shard_param_info.in_shard: + return {} + # Flatten and shard the state. + new_optim_state: dict[str, Any] = {} + intra_param_start_idx = shard_param_info.intra_param_start_idx + intra_param_end_idx = shard_param_info.intra_param_end_idx + for state_name, value in optim_state.items(): + if ( + torch.is_tensor(value) + and value.dim() > 0 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + value = value.flatten()[ + intra_param_start_idx : intra_param_end_idx # type: ignore[operator] + + 1 + ].clone() + new_optim_state[state_name] = value + return new_optim_state + + +def _flatten_optim_state_dict( + optim_state_dict: dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError( + '`optim_state_dict` must have the keys "state"' + "to be a valid optimizer state dict" + ) + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn]: + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + if len(fqns) != 1: + raise AssertionError( + f"use_orig_params is True but there are multiple FQNs, {fqns}." + ) + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn( + f"optim_state[{key}] is not on rank{fsdp_state.rank}.", + stacklevel=2, + ) + + else: + raise RuntimeError( + f"The state of {key} is empty. This should happen when " + "use_orig_params=True." + ) + else: # do not flatten non-FSDP parameters' states + if len(fqns) != 1: + raise AssertionError(f"Expected len(fqns) == 1, got {len(fqns)}") + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. + unflat_osd_state[fqn][state_name] = param_state.cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _flatten_optim_state( + fsdp_param_info: FSDPParamInfo, + unflat_osd_state: dict[str, dict[str, Any]], + unflat_param_names: list[str], +) -> dict[str, Any]: + """ + Flattens the optimizer state in ``full_optim_state_dict`` for a single + flat parameter in ``fsdp_param_info`` corresponding to the unflattened + parameter names in ``unflat_param_names``. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the + optimizer state dict corresponding to the unflattened parameters. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the flat parameter ``flat_param``. + + Returns: + Dict[str, Any]: A :class:`dict` mapping state names to their values for + a particular flat parameter. The sharded optimizer state dict's "state" + part will map a key to this returned value. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + num_unflat_params = len(unflat_param_names) + if num_unflat_params <= 0: + raise AssertionError( + "Expects at least one unflattened parameter corresponding to the flat parameter" + ) + unflat_param_shapes = flat_param._shapes + num_unflat_param_shapes = len(unflat_param_shapes) + if num_unflat_params != num_unflat_param_shapes: + raise AssertionError( + f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + ) + + # Check if these unflattened parameters have any optimizer state + has_state = [ + bool(unflat_param_name in unflat_osd_state) + for unflat_param_name in unflat_param_names + ] + # If none of the unflattened parameters comprising this flat parameter have + # any state, then we do not want an entry in the optimizer state dict + if not any(has_state): + return {} # no need to flatten any state + # There may still be some unflattened parameters with state and some + # without + unflat_param_states = [ + _gather_state_dict( + unflat_osd_state[unflat_param_name], + pg=fsdp_state.process_group, + device=fsdp_state.compute_device, + ) + if unflat_param_name in unflat_osd_state + else None + for unflat_param_name in unflat_param_names + ] + # Check that the unflattened parameters have the same state names + state_names = None + # pyrefly: ignore [bad-assignment] + for unflat_param_state in unflat_param_states: + if unflat_param_state is None: + continue + if state_names is None: + state_names = set(unflat_param_state.keys()) + else: + if state_names != set(unflat_param_state.keys()): + raise ValueError( + "Differing optimizer state names for the unflattened " + f"parameters: {unflat_param_names}" + ) + if state_names is None: + raise AssertionError(f"Expected state_names to be not None, got {state_names}") + + # Flatten the state + flat_state: dict[str, Optional[torch.Tensor]] = {} + for state_name in state_names: + state_values = [ + unflat_param_state[state_name] if unflat_param_state is not None else None + for unflat_param_state in unflat_param_states + ] + non_none_state_values = [v for v in state_values if v is not None] + # If all ranks have None, this is a None value + if not non_none_state_values: + flat_state[state_name] = None + continue + are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True + for v in non_none_state_values: + are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 + are_zero_dim_tensors &= _is_zero_dim_tensor(v) + are_non_tensors &= not torch.is_tensor(v) + types = {type(v) for v in non_none_state_values} + if len(types) != 1 or not ( + are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors + ): + raise ValueError( + f"Differing optimizer state types for state {state_name}, " + f"values {non_none_state_values}, and unflattened parameter " + f"names {unflat_param_names}" + ) + if are_pos_dim_tensors: + flat_tensor = _flatten_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + unflat_param_shapes, + handle, + ) + # Shard the flattened tensor immediately to minimize max memory + # usage + if ( + fsdp_state.world_size != 1 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + sharded_flat_tensor, _ = FlatParamHandle._get_shard( + flat_tensor, + fsdp_state.rank, + fsdp_state.world_size, + ) + else: + sharded_flat_tensor = flat_tensor + flat_state[state_name] = sharded_flat_tensor + elif are_zero_dim_tensors: + flat_state[state_name] = _flatten_zero_dim_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + ) + else: + if not are_non_tensors: + raise AssertionError( + f"Expected are_non_tensors to be True, got {are_non_tensors}" + ) + flat_state[state_name] = _flatten_non_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + ) + + return flat_state + + +def _flatten_tensor_optim_state( + state_name: str, + pos_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], + unflat_param_shapes: Sequence[torch.Size], + handle: FlatParamHandle, +) -> torch.Tensor: + """ + Flattens the positive-dimension tensor optimizer state given by the values + ``tensors`` for the state ``state_name`` for a single flat parameter + from ``handle`` corresponding to the unflattened parameter names + ``unflat_param_names`` and unflatted parameter shapes + ``unflat_param_shapes``. This flattens each unflattened parameter's tensor + state into one tensor. + + NOTE: We use zero tensors for any unflattened parameters without state + since some value is required to fill those entries. This assumes that the + zero tensor is mathematically equivalent to having no state, which is true + for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all + optimizers. + + Args: + state_name (str): Optimizer state name. + pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor + optimizer state values for the unflattened parameters corresponding + to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes + corresponding to the single flat parameter. + handle (FlatParamHandle): The flat parameter's handle. + + Returns: + torch.Tensor: A flat tensor containing the optimizer state + corresponding to ``state_name`` constructed by concatenating the + unflattened parameter tensor states in ``pos_dim_tensors`` (using zero + tensors for any unflattened parameters without the state). + """ + flat_param = handle.flat_param + non_none_tensors = [t for t in pos_dim_tensors if t is not None] + # Check that all are tensors with the same dtype + dtypes = {t.dtype for t in non_none_tensors} + if len(dtypes) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have positive-dimension tensor state with the " + f"same dtype but got dtypes {dtypes} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + dtype = next(iter(dtypes)) + # Check that each tensor state matches its parameter's shape + for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes): + if tensor is None and len(shape) == 0: + raise ValueError("Flattening a zero-dimension parameter is not supported") + elif tensor is not None and tensor.shape != shape: + raise ValueError( + "Tensor optimizer state does not have same shape as its " + f"parameter: {tensor.shape} {shape}" + ) + # Flatten the tensor states: we do not need to add any right-hand-side + # padding since the flat optimizer state tensor is sharded via + # `_get_shard()`, which pads the shard as needed (just like for the flat + # parameter) + cpu_device = torch.device("cpu") + tensors_to_flatten = [ + torch.flatten(state_value.to(cpu_device)) + if state_value is not None + else torch.flatten( + torch.zeros( + size=shape, + dtype=dtype, + device=cpu_device, + ) + ) + for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes) + ] + flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) + flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] + if flat_tensor.shape != flat_param_shape: + raise AssertionError( + f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}" + ) + return flat_tensor + + +def _flatten_zero_dim_tensor_optim_state( + state_name: str, + zero_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], +) -> torch.Tensor: + """ + Flattens the zero-dimension tensor optimizer state given by the values + ``zero_dim_tensors`` for the state ``state_name`` for a single flat + parameter corresponding to the unflattened parameter names + ``unflat_param_names`` by enforcing that all tensors are the same and using + that common value. + + NOTE: The requirement that the tensors are the same across all unflattened + parameters comprising the flat parameter is needed to maintain the + invariant that FSDP performs the same computation as its non-sharded + equivalent. This means that none of the unflattened parameters can be + missing this state since imposing a value may differ from having no value. + For example, for Adam's "step", no value means maximum bias correction, + while having some positive value means less bias correction. + + Args: + state_name (str): Optimizer state name. + zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state + for the unflattened parameters corresponding to the single + flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + torch.Tensor: A zero-dimensional tensor giving the value of the state + ``state_name`` for all unflattened parameters corresponding to the + names ``unflat_param_names``. + """ + non_none_tensors = [t for t in zero_dim_tensors if t is not None] + # Enforce that all have the same value and dtype + values_set = {t.item() if t is not None else None for t in zero_dim_tensors} + dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors} + if ( + len(non_none_tensors) != len(zero_dim_tensors) + or len(values_set) != 1 + or len(dtypes) != 1 + ): + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {values_set} and dtypes {dtypes} for state " + f"{state_name} and unflattened parameter names " + f"{unflat_param_names}" + ) + value = next(iter(values_set)) + dtype = next(iter(dtypes)) + return torch.tensor(value, dtype=dtype, device=torch.device("cpu")) + + +def _flatten_non_tensor_optim_state( + state_name: str, + non_tensors: list[Any], + unflat_param_names: list[str], +) -> Any: + """ + Flattens the non-tensor optimizer state given by the values ``non_tensors`` + for the state ``state_name`` for a single flat parameter corresponding + to the unflattened parameter names ``unflat_param_names`` by enforcing that + all values are the same and using that common value. + + See the note in :func:`_flatten_zero_dim_tensor_optim_state`. + + Args: + state_name (str): Optimizer state name. + non_tensors (List[Any]): Non-tensor optimizer state for the unflattened + parameters corresponding to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + Any: A non-tensor giving the value of the state ``state_name`` for all + unflattened parameters corresponding to the names + ``unflat_param_names``. + """ + non_none_non_tensors = [nt for nt in non_tensors if nt is not None] + # Enforce that all have the same value (same type already checked) + non_tensor_set = set(non_tensors) + if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {non_tensor_set} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + non_tensor = next(iter(non_tensor_set)) + return non_tensor + + +def _rekey_sharded_optim_state_dict( + sharded_osd: dict[str, Any], + model: nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + using_optim_input: bool, + is_named_optimizer: bool = False, +) -> dict[str, Any]: + """ + Rekeys the optimizer state dict from unflattened parameter names to flat + parameter IDs according to the calling rank's ``optim``, which may be + different across ranks. In particular, the unflattened parameter names are + represented as :class:`_OptimStateKey` s. + """ + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast( + dict[nn.Parameter, Union[int, str]], + ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + # All parameter keys in `param_to_param_key` should be in + # `param_to_fqns` -- strict inequality follows when not all parameters are + # passed to the optimizer + if len(param_to_param_key) > len(param_to_fqns): + raise AssertionError( + f"Expected len(param_to_param_key) <= len(param_to_fqns), got {len(param_to_param_key)} > {len(param_to_fqns)}" + ) + + unflat_param_names_to_flat_param_key: dict[ + tuple[str, ...], Union[int, str] + ] = {} # for "state" + unflat_param_name_to_flat_param_key: dict[ + str, Union[int, str] + ] = {} # for "param_groups" + for param, unflat_param_names in param_to_fqns.items(): + if param not in param_to_param_key: + # This parameter was not passed to the optimizer + continue + flat_param_key = param_to_param_key[param] + unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key + for unflat_param_name in unflat_param_names: + unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key + + sharded_osd_state = sharded_osd["state"] + rekeyed_osd_state: dict[Union[str, int], Any] = {} + for key, param_state in sharded_osd_state.items(): + if isinstance(key, str): + rekeyed_osd_state[key] = param_state + continue + flat_param_key = unflat_param_names_to_flat_param_key.get( + key.unflat_param_names, key.unflat_param_names + ) + # pyrefly: ignore [unsupported-operation] + rekeyed_osd_state[flat_param_key] = param_state + + # Only process param_groups if it exists in sharded_osd + if "param_groups" in sharded_osd: + rekeyed_osd_param_groups: list[dict[str, Any]] = [] + for unflat_param_group in sharded_osd["param_groups"]: + flat_param_group = copy.deepcopy(unflat_param_group) + flat_param_keys = sorted( + { + unflat_param_name_to_flat_param_key[unflat_param_name] + for unflat_param_name in unflat_param_group["params"] + } + ) + flat_param_group["params"] = flat_param_keys + rekeyed_osd_param_groups.append(flat_param_group) + return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups} + else: + return {"state": rekeyed_osd_state} + + +def _get_param_id_to_param_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[int, nn.Parameter]: + """ + Constructs a mapping from parameter IDs to parameters. This may be used + both for models with ``FlatParameter`` s and without. + + NOTE: This method is only preserved for backward compatibility. The method + :meth:`_get_param_key_to_param` is the preferred code path that does not + rely on ``optim_input``. + + NOTE: We critically assume that, whether the optimizer input is a list of + parameters or a list of parameter groups, :class:`torch.optim.Optimizer` + enumerates the parameter IDs in order. In other words, for a parameter list + input, the parameter IDs should be in that list order, and for a parameter + groups input, the parameter IDs should be in order within each parameter + group and in order across parameter groups. + + Args: + model (nn.Module): Model whose parameters are passed into the + optimizer. + optim_input (Optional[Union[List[Dict[str, Any]], + Iterable[nn.Parameter]]]): Input passed into the optimizer + representing either a :class:`list` of parameter groups or an + iterable of parameters; if ``None``, then this method assumes the + input was ``model.parameters()``. (Default: ``None``) + + Returns: + List[nn.Parameter]: Mapping from parameter IDs to parameters, + where the parameter ID is implicitly the index in the :class:`list`. + """ + # Assume the standard case of passing `model.parameters()` to the optimizer + # if `optim_input` is not specified + if optim_input is None: + return dict(enumerate(model.parameters())) + try: + # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [redundant-cast] + params = cast(list[nn.Parameter], list(optim_input)) + except TypeError as e: + raise TypeError( + "Optimizer input should be an iterable of Tensors or dicts, " + f"but got {optim_input}" + ) from e + if len(params) == 0: + raise ValueError("Optimizer input should not be empty") + + # Check if the optimizer input represents tensors or parameter groups + all_tensors = True + all_dicts = True + for param in params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError("Optimizer input should be an iterable of Tensors or dicts") + if all_tensors: + return dict(enumerate(params)) + if not all_dicts: + raise AssertionError(f"Expected all_dicts to be True, got {all_dicts}") + param_id_to_param: list[nn.Parameter] = [] + for param_group in params: + has_params_key = "params" in param_group # type: ignore[operator] + if not has_params_key: + raise AssertionError( + 'A parameter group should map "params" to a list of the parameters in the group' + ) + # Implicitly map `flat_param_id` (current length of the list) to + # `param` + param_id_to_param.extend(param_group["params"]) # type: ignore[index] + return dict(enumerate(param_id_to_param)) + + +def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]: + """ + Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes + from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical" + because ``FlatParameter`` s do not come from the original module but are + registered only after FSDP has been applied. This function returns the FSDP-given + name for the ``FlatParameter`` (usually module._flat_param) as opposed to the + canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``). + + Consequently, this function will only return a non-empty mapping if FSDP was + applied with ``use_orig_params=False`` as, otherwise, the original parameters + are used within the module and there would be no ``FlatParameter`` s in the module. + + """ + + def module_fn(module, prefix, tree_level, flat_param_to_fqn): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + if not isinstance(param, FlatParameter): + continue + fqn = clean_tensor_name(prefix + param_name) + flat_param_to_fqn[param] = fqn + + def return_fn(flat_param_to_fqn): + return flat_param_to_fqn + + flat_param_to_fqn_ret: dict[FlatParameter, str] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + flat_param_to_fqn_ret, + ) + + +def _get_param_key_to_param( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[Union[int, str], nn.Parameter]: + """ + Constructs a mapping from parameter keys to parameters. For the regular + optimizers, the keys are parameter IDs. For NamedOptimizer, the keys + are FQNs. This API may be used both for models with ``FlatParameter`` s and + without. + """ + clean_fqn_to_curr_fqn: dict[str, str] = {} + if is_named_optimizer: + if param_to_fqns is None or flat_param_to_fqn is None: + raise AssertionError( + "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." + ) + if model is None: + raise AssertionError(f"Expected model to be not None, got {model}") + for key, _ in _named_parameters_with_duplicates(model): + clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key + + param_key_to_param: dict[Union[str, int], nn.Parameter] = {} + pid = 0 + for param_group in optim.param_groups: + if is_named_optimizer: + for param in param_group["params"]: + if flat_param_to_fqn is None: + raise AssertionError( + f"Expected flat_param_to_fqn to be not None, got {flat_param_to_fqn}" + ) + if param in flat_param_to_fqn: + # FlatParameter case + key = flat_param_to_fqn[param] + else: + if param_to_fqns is None: + raise AssertionError( + f"Expected param_to_fqns to be not None, got {param_to_fqns}" + ) + # use_orig_params case + if len(param_to_fqns[param]) != 1: + raise AssertionError( + f"Expected len(param_to_fqns[param]) == 1, got {len(param_to_fqns[param])}" + ) + key = param_to_fqns[param][0] + try: + key = clean_fqn_to_curr_fqn[key] + except KeyError as e: + raise KeyError( + f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}." + ) from e + param_key_to_param[key] = param + else: + for param in param_group["params"]: + param_key_to_param[pid] = param + pid += 1 + + return param_key_to_param + + +def _get_param_to_param_key( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[nn.Parameter, Union[int, str]]: + """ + Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API + only supports the case where `optim` is a regular optimizer, not NamedOptimizer. + So the parameter keys will be parameter ids. + """ + param_id_to_param = _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _get_param_to_param_id_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[nn.Parameter, int]: + """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" + param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _check_missing_keys_on_rank( + r0_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]], + param_key_to_param: dict[Union[str, int], nn.Parameter], + group: Optional[dist.ProcessGroup], +) -> None: + # Ensure that all ranks have at least the optimizer states needed by + # rank 0's optimizer + missing_keys: list[_OptimStateKey] = [] + for r0_optim_state_key in r0_optim_state_keys: + if r0_optim_state_key not in optim_state_key_to_param_key: + # A parameter from rank 0's optimizer does not exist for this + # rank's optimizer + missing_keys.append(r0_optim_state_key) + continue + param_key = optim_state_key_to_param_key[r0_optim_state_key] + if isinstance(param_key, int): + if not (param_key >= 0 and param_key < len(param_key_to_param)): + raise AssertionError("Check the `param_key_to_param` construction") + # We cannot use FSDPState.compute_device as this API is a global view. + device = _get_pg_default_device(group) + num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) + dist.all_reduce(num_missing, group=group) + if num_missing.item() > 0: + obj_list = [None for _ in range(dist.get_world_size(group))] + dist.all_gather_object(obj_list, missing_keys, group=group) + error_msg = ( + "FSDP currently requires each rank to have at least the " + "optimizer states needed by rank 0's optimizer but some ranks " + "are missing some of those states" + ) + for rank, keys in enumerate(obj_list): + keys = cast(list[_OptimStateKey], keys) + if len(keys) > 0: + error_msg += ( + f"\nRank {rank} is missing states for the parameters: " + f"{[key.unflat_param_names for key in keys]}" + ) + raise RuntimeError(error_msg) + + +def _map_param_key_to_optim_keys( + optim_state_dict: dict[str, Any], + group: Optional[dist.ProcessGroup], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + merge_keys: bool = False, +) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]: + """ + Construct the local mapping between the ``_OptimStateKey`` and parameter keys + and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0 + must contain all the ``_OptimStateKey``, an exception will be raised otherwise. + Note that ``merge_keys`` should equal to ``use_orig_params``. + """ + rank = dist.get_rank(group) + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local + all_optim_state_keys: list[_OptimStateKey] = [] + + for param_key, param in param_key_to_param.items(): + # Do not include parameters without state to avoid empty mappings + # just like in normal `torch.optim.Optimizer.state_dict()` + if param_key not in optim_state_dict["state"]: + continue + fqns = param_to_fqns[param] + is_fsdp_managed = isinstance(param, FlatParameter) + if is_fsdp_managed: + if fqns[0] not in fqn_to_fsdp_param_info: + raise AssertionError( + f"Expected {fqns[0]} to be in fqn_to_fsdp_param_info, got keys: {list(fqn_to_fsdp_param_info.keys())}" + ) + is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info + optim_state_key = _OptimStateKey( + unflat_param_names=tuple(fqns), + is_fsdp_managed=is_fsdp_managed, + ) + if rank == 0 or merge_keys: + all_optim_state_keys.append(optim_state_key) + optim_state_key_to_param_key[optim_state_key] = param_key + + if merge_keys: + all_keys: list[list[_OptimStateKey]] = [ + [] for _ in range(dist.get_world_size(group)) + ] + dist.all_gather_object(all_keys, all_optim_state_keys, group=group) + merge_all_optim_state_keys = [*chain.from_iterable(all_keys)] + all_optim_state_keys = sorted(set(merge_all_optim_state_keys)) + else: + key_obj_list: list[Optional[list[_OptimStateKey]]] = ( + [all_optim_state_keys] if rank == 0 else [None] + ) + dist.broadcast_object_list(key_obj_list, src=0, group=group) + if key_obj_list[0] is None: + raise AssertionError( + f"Expected key_obj_list[0] to be not None, got {key_obj_list[0]}" + ) + all_optim_state_keys = key_obj_list[0] + _check_missing_keys_on_rank( + all_optim_state_keys, + optim_state_key_to_param_key, + param_key_to_param, + group, + ) + + return all_optim_state_keys, optim_state_key_to_param_key + + +def _unflatten_param_groups( + state_dict: dict[str, Any], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], +) -> list[dict[str, Any]]: + param_groups: list[dict[str, Any]] = [] + for flat_param_group in state_dict["param_groups"]: + unflat_param_group = copy.deepcopy(flat_param_group) + param_group_params = [ + param_key_to_param[flat_param_key] + for flat_param_key in flat_param_group["params"] + ] + nested_unflat_param_names = [ + param_to_fqns[param] for param in param_group_params + ] + unflat_param_group["params"] = [ + *chain.from_iterable(nested_unflat_param_names) + ] # flatten the list of lists + param_groups.append(unflat_param_group) + return param_groups + + +def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool: + """ + Returns whether the state_dict is from a NamedOptimizer. + This function checks that the keys in the state_dict['state'] are strings + (which usually are FQNs) versus integers (which usually refer to param_ids + from a vanilla torch.optim.Optimizer). + """ + state = optim_state_dict.get("state") + if not state: + # If we cannot find a state, assume it is not NamedOptimizer as + # NamedOptimizer has eager initialization. + return False + try: + key = next(iter(state.keys())) + except Exception as e: + raise Exception(optim_state_dict) from e # noqa: TRY002 + return isinstance(key, str) + + +@dataclass +class StateInfo: + # The key of these dictionaries are the state name, e.g., `exp_avg`. + tensors: dict[str, _PosDimTensorInfo] + scalar_tensors: dict[str, torch.Tensor] + non_tensors: dict[str, Any] + + +def _allgather_state_info( + fsdp_state: _FSDPState, + input_states: dict[str, Any], +) -> list[dict[str, StateInfo]]: + """ + Given the ``input_states``, allgather StateInfo for each state. The function + uses all_gather_object to gather StateInfo so no GPU tensors are sent. + """ + + processed_state_dict: dict[str, StateInfo] = {} + gathered_state_info: list[dict[str, StateInfo]] = [ + {} for _ in range(fsdp_state.world_size) + ] + + for fqn, optim_state in input_states.items(): + # Allgather the scalar tensor state, non-tensor states and tensors metadata. + processed_state = StateInfo({}, {}, {}) + for state_name, value in sorted_items(optim_state): + if torch.is_tensor(value): + if value.dim() == 0: + # Ensure that `step` is on CPU. + processed_state.scalar_tensors[state_name] = value.cpu() + else: + processed_state.tensors[state_name] = _PosDimTensorInfo( + value.shape, value.dtype + ) + else: + processed_state.non_tensors[state_name] = value + processed_state_dict[fqn] = processed_state + dist.all_gather_object( + gathered_state_info, + processed_state_dict, + group=fsdp_state.process_group, + ) + return gathered_state_info + + +def _convert_all_state_info( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + output_states: dict[str, dict[str, Any]], +) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API converted + the StateInfo into the original state if the state is not a non-scalar + tensor. For a multi-dimensional tensor, the local state will be stored in + ``state_buffer`` in a correct order for later allgather purpose. + """ + + state_buffers: dict[str, list[Optional[torch.Tensor]]] = {} + + for fqn, gathered_state in output_states.items(): + state_info = [s[fqn] for s in gathered_state_info] + all_tensor_states = sorted({n for state in state_info for n in state.tensors}) + empty_ranks: set[int] = set() + dtype: Optional[torch.dtype] = None + # First check all the non-scalar states and get the information of + # states on each rank. + for state_name in all_tensor_states: + numels = [] + _empty_ranks: set[int] = set() + for rank, object_state in enumerate(state_info): + numels.append(0) + info = object_state.tensors.get(state_name, None) + if info is not None: + numels[-1] = info.shape.numel() + if not dtype: + dtype = info.dtype + else: + if dtype != info.dtype: + raise AssertionError( + f"Expected dtype == info.dtype, got {dtype} != {info.dtype}" + ) + if numels[-1] == 0: + _empty_ranks.add(rank) + + if not (not empty_ranks or empty_ranks == _empty_ranks): + raise AssertionError( + f"Expected empty_ranks to be empty or equal to _empty_ranks, got {empty_ranks} vs {_empty_ranks}" + ) + empty_ranks = _empty_ranks + if state_name not in state_buffers: + state_buffers[state_name] = [ + None for _ in fsdp_param_info.param_indices + ] + local_state = input_states[fqn].get(state_name, None) + # N.B. We need to move the state to compute_device. The reason is + # not yet clear and we need to figure out why the state may be on a + # different device. + if local_state is not None: + local_state = local_state.to(fsdp_param_info.state.compute_device) + state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state + + # Restoring the scalar and non-tensor states. If the corresponding + # non-scalar states do not exist on the rank, we also skip the scalar + # non-tensor states on that rank. + for rank, object_state in enumerate(state_info): + if rank in empty_ranks: + continue + for name, non_tensor_value in object_state.non_tensors.items(): + curr_non_tensor_value = gathered_state.get(name, None) + if not ( + curr_non_tensor_value is None + or curr_non_tensor_value == non_tensor_value + ): + raise AssertionError( + f"Rank {rank} has different values for {name}: {non_tensor_value}." + + f" Other ranks: {curr_non_tensor_value}" + ) + gathered_state[name] = non_tensor_value + + for name, scalar_tensor_value in object_state.scalar_tensors.items(): + curr_scalar_tensor_value = gathered_state.get(name, None) + if not ( + curr_scalar_tensor_value is None + or torch.equal(scalar_tensor_value, curr_scalar_tensor_value) + ): + raise AssertionError( + f"Rank {rank} has different values for {name}: {scalar_tensor_value}." + + f" Other ranks: {curr_scalar_tensor_value}" + ) + gathered_state[name] = scalar_tensor_value + + return dtype, state_buffers # type: ignore[possibly-undefined] + + +def _unflatten_orig_param_states( + fsdp_param_info: FSDPParamInfo, + output_states: dict[str, dict[str, Any]], + state_name: str, + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> None: + """ + Given a output state dict, ``output_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), and the values + are gathered states, unflatten the states to the original dimensions. + + This function performs the unflattening process in-place. + """ + if not to_save: + return + flat_param = fsdp_param_info.handle.flat_param + fsdp_state = fsdp_param_info.state + for fqn, gathered_state in output_states.items(): + value = gathered_state[state_name] + param_idx = fsdp_param_info.param_indices[fqn] + + # TODO: This solution is not general and only apply to PTD TP solution. + if isinstance(value, DTensor): + placement = value.placements[0] + # If gathered state is a DTensor and its TP placement is not Replicate(), we need to + # gather the tensor on its TP dimension before chunking them into DTensor again. + if placement != Replicate(): + placement_dim = placement.dim # type: ignore[attr-defined] + value.redistribute(placements=(Replicate(),)) + reshape_size = list(flat_param._shapes[param_idx]) + reshape_size[placement_dim] *= value.device_mesh.size(0) + reshape_size = torch.Size(reshape_size) + value = value.reshape(reshape_size) + # If gathered state is a replicate DTensor, we directly reshape it. + else: + value = value.reshape(flat_param._shapes[param_idx]) + else: + # If gathered state is a tensor, we directly reshape it into unflatten state. + value = value.reshape(flat_param._shapes[param_idx]) + + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + if fsdp_state._device_mesh is None: + raise AssertionError( + f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}" + ) + value = _ext_chunk_dtensor( + value, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + if fsdp_state.process_group is None: + raise AssertionError( + f"Expected process_group to be not None, got {fsdp_state.process_group}" + ) + value = _ext_chunk_tensor( + value, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + elif not cpu_offload: + with SimpleProfiler.profile("clone"): + value = value.detach().clone() + + if cpu_offload: + with SimpleProfiler.profile(SimpleProfiler.Type.D2H): + value = value.cpu() + gathered_state[state_name] = value + + +def _allgather_orig_param_states( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, dict[str, Any]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API allgathers + all tensor states and restore non-tensor states from ``gathered_state_info``. + """ + fsdp_state = fsdp_param_info.state + if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: + logger.info( + "Memory Summary before calling to _allgather_orig_param_states %s", + fsdp_state._device_handle.memory_summary(), + ) + + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states} + + dtype, state_buffers = _convert_all_state_info( + fsdp_param_info, gathered_state_info, input_states, output_states + ) + + if len(state_buffers) == 0: + return output_states + + has_state_params: list[bool] = [ + fqn in output_states for fqn, idx in fsdp_param_info.param_indices.items() + ] + + # Loop through the ``state_buffers`` and construct the flattened, concatenated, + # sharded states. The size of the constructed state will be the same size as + # flat_param (also sharded). + # Then we perform an allgather_into_tensor to get the full flat_param state. + # The full flat_param state is the result of concatenation of multiple states + # the order of of flat_param._fqns. + # The final step is to split the flat_param state into original param states + # and return the result. + flat_param = fsdp_param_info.handle.flat_param + empty_func = functools.partial( + torch.empty, dtype=dtype, device=fsdp_state.compute_device + ) + gathered_tensor = empty_func(flat_param._padded_unsharded_size) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + for state_name, buffers in state_buffers.items(): + local_buffers: list[torch.Tensor] = [] + begin = fsdp_state.rank * flat_param._sharded_size.numel() + # End is inclusive. + end = begin + flat_param._sharded_size.numel() - 1 + # param_idx corresponds to the parameter index in the FlatParameter. + mem_offset, param_idx = 0, 0 + for numel, is_padding in zip( + flat_param._numels_with_padding, flat_param._is_padding_mask + ): + frozen_and_no_state = not is_padding and ( + not fsdp_param_info.param_requires_grad[param_idx] + and not has_state_params[param_idx] + ) + + if is_padding or frozen_and_no_state: + # This memory range is a padding or the param is frozen and does + # not require gradient. For the later case, we treat it as a + # padding and add empty values to the local_buffers. + + padding_begin, padding_end = mem_offset, mem_offset + numel - 1 + if padding_begin <= begin <= padding_end: + # The range is an align padding before the first parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + padding_end - begin + 1 + if end >= padding_end + else end - begin + 1 + ) + elif padding_begin <= end <= padding_end: + # The range is an align padding after the last parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + end - padding_begin + 1 + if begin <= padding_begin + else end - begin + 1 + ) + elif begin < padding_begin <= padding_end < end: + # The range is an align padding that is completely in the + # shard. + padding_len = numel + else: + padding_len = 0 + if padding_len: + local_buffers.append(empty_func(padding_len)) + + if not is_padding: + # This memory range is a parameter in FlatParameter. So there + # should be an corresponding state in the optimizer unless the + # parameter is frozen, which we treat it as a padding above. + + # We need to check if this rank owns the buffer. If this is None: + # 1.) the rank does not own any part of the original parameter. + # As a result, there is no corresponding optimizer state on + # the rank as well. + # 2.) the parameter is frozen AND no optimizer state for the + # parameter. If a parameter is frozen, there can still be + # optimizer state if the parameter is not frozen in the + # previous steps. + if buffers[param_idx] is not None: + local_buffers.append(cast(torch.Tensor, buffers[param_idx])) + param_idx += 1 + + mem_offset += numel + + shard_numel_padded = flat_param._sharded_size.numel() - ( + sum(t.numel() for t in local_buffers) + ) + + if flat_param._shard_numel_padded != shard_numel_padded: + raise AssertionError( + "Manually calculated _sharded_numel_padded is incorrect. " + f"_shard_numel_padded={flat_param._shard_numel_padded}, " + f"shard_numel_padded={shard_numel_padded}, " + f"_sharded_size.numel={flat_param._sharded_size.numel()}, " + f"_numels_with_padding={flat_param._numels_with_padding}, " + f"begin={begin}, end={end}," + ) + if shard_numel_padded > 0: + # Add right-handed padding. + local_buffers.append(empty_func(shard_numel_padded)) + local_shard = torch.cat(local_buffers) + if local_shard.numel() * fsdp_state.world_size != gathered_tensor.numel(): + raise AssertionError( + "The size of local shard times the world size should equal to the " + "gathered tensor size. The inconsistency may be from a bug of " + "FlatParameter's metadata or the reconstruction logic in optimizer " + "state dict." + ) + fsdp_state._device_handle.synchronize() + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + gathered_tensor, local_shard, group=fsdp_state.process_group + ) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + + unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()] + flat_param_handle = fsdp_param_info.handle + orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor) + if len(orig_states) != len(fsdp_param_info.param_indices): + raise AssertionError( + "The number of parameters from FlatParameter is not consistent to " + "the number of states used by optimizer state dict reconstruction " + "logic." + ) + for fqn, idx in fsdp_param_info.param_indices.items(): + if fsdp_param_info.param_requires_grad[idx] or fqn in output_states: + output_states[fqn][state_name] = orig_states[idx] + + _unflatten_orig_param_states( + fsdp_param_info, + output_states, + state_name, + shard_state, + to_save, + cpu_offload, + ) + + del gathered_tensor + return output_states + + +def _gather_all_orig_param_state( + fsdp_param_info: FSDPParamInfo, + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, Any]: + """ + Given a optimizer state dict, ``input_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), gather all the + states and unflatten them to the original dimensions. Note that all the + params referred by the ``input_states`` must be managed by FSDP. + """ + fsdp_state = fsdp_param_info.state + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + return input_states if to_save else {} + + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ): + gathered_state_info = _allgather_state_info(fsdp_state, input_states) + output_states = _allgather_orig_param_states( + fsdp_param_info, + gathered_state_info, + input_states, + shard_state, + to_save, + cpu_offload, + ) + if to_save: + for key, idx in fsdp_param_info.param_indices.items(): + if key in output_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + + raise RuntimeError( + f"{key} is not in the output state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the output_states has the param keys " + f"{sorted(output_states.keys())}." + ) + return output_states + else: + return {} + + +def _convert_state_with_orig_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo + # usually corresponds to multiple parameters. We could not use FSDPParamInfo + # as the key because FSDPParamInfo is not hashable. As a result, we fall back + # to `id(FSDPParamInfo)`, which the type is an integer. + all_states: dict[int, dict[str, Any]] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key + ) + + if param_key is None and not optim_state_key.is_fsdp_managed: + continue + + if optim_state_key.is_fsdp_managed: + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info.get(fqn) + if fsdp_param_info is None: + # This can happen if the not all FSDP instances have all the + # parameters. This can happen with FSDP + some MPMD style + # parallelism. + + # TODO: it is unclear if we need to do the same check with + # non-FSDP managed keys. + continue + state = {} if param_key is None else optim_state_dict[param_key] + if id(fsdp_param_info) not in all_states: + all_states[id(fsdp_param_info)] = {} + all_states[id(fsdp_param_info)][fqn] = state + + elif to_save: + if len(optim_state_key.unflat_param_names) != 1: + raise AssertionError( + f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}" + ) + unflat_param_name = optim_state_key.unflat_param_names[0] + with SimpleProfiler.profile("none_fsdp_managed_copy"): + param_key = cast(Union[str, int], param_key) + fsdp_osd_state[unflat_param_name] = copy.copy( + optim_state_dict[param_key] + ) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + # Instead of gathering the state of each parameter individually, we perform + # the gathering all at once to speed up the process. + for _all_states in all_states.values(): + fqn = next(iter(_all_states.keys())) + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if len(fsdp_param_info.param_requires_grad) <= 0: + raise AssertionError( + "With use_orig_params, FSDPParamInfo should have requires_grad " + "information. However, the length is zero." + ) + for key, idx in fsdp_param_info.param_indices.items(): + if key in _all_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + raise RuntimeError( + f"{key} is not in the optimizer state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the optimizer has the param keys " + f"{sorted(_all_states.keys())}." + ) + fsdp_osd_state.update( + _gather_all_orig_param_state( + fsdp_param_info, + _all_states, + shard_state, + to_save, + cpu_offload, + ) + ) + + return fsdp_osd_state + + +def _convert_state_with_flat_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key + ) + + if param_key is None: + raise AssertionError( + "If use_orig_params is False, we must be able to find the " + f"corresponding param id. {optim_state_key} {param_key}" + ) + + if optim_state_key.is_fsdp_managed: + # If there are multiple unflat_param_names (not use_orig_params), + # they share the same FSDPParamInfo. So the first unflat_param_name + # is sufficient to fetch the FSDPParamInfo. + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + unflat_state = _unflatten_optim_state( + fsdp_param_info, + optim_state_dict[param_key], + to_save, + shard_state, + cpu_offload, + ) + if to_save: + if len(unflat_state) != len(optim_state_key.unflat_param_names): + raise AssertionError( + f"Expected len(unflat_state) == len(optim_state_key.unflat_param_names), " + f"got {len(unflat_state)} != {len(optim_state_key.unflat_param_names)}" + ) + fsdp_osd_state.update( + zip( + optim_state_key.unflat_param_names, + unflat_state, + ) + ) + elif to_save: + if len(optim_state_key.unflat_param_names) != 1: + raise AssertionError( + f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}" + ) + unflat_param_name = optim_state_key.unflat_param_names[0] + fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key]) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + return fsdp_osd_state + + +@torch.no_grad() +def _optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + rank0_only: bool, + shard_state: bool, + group: Optional[dist.ProcessGroup], + using_optim_input: bool, + use_orig_params: bool = False, + cpu_offload: bool = True, +) -> dict[str, Any]: + """ + Consolidates the optimizer state and returns it as a :class:`dict` + following the convention of :meth:`torch.optim.Optimizer.state_dict`, + i.e. with keys ``"state"`` and ``"param_groups"``. + The flat parameters in ``FSDP`` modules contained in ``model`` are mapped + back to their unflattened parameters. + + Parameter keys are not well-defined. For a regular optimizer, the optimizer + state_dict contains a mapping from parameter IDs to parameter states. + Parameter IDs are the order of parameters in ``optim.param_groups()`` across + all the groups. This API also allows user to pass ``optim_input`` for the + mapping between parameters and parameter IDs. Using ``optim_input`` is being + deprecated. + + If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not + contain parameter IDs mapping but a mapping from parameter FQNs to parameter + states. This API finds the mapping from FQNs to parameters if the optimizer + is a ``NamedOptimizer``. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP knows how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- those are managed by other parallelisms and FSDP does not + know how to handle/aggregate them. + + Args: + model (nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + shard_state (bool): If ``True``, shard and distribute all + non-zero-dimension states. + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, + then nonzero ranks return an empty :class:`dict`. + """ + SimpleProfiler.reset() + cm = ExitStack() + cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL)) + _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model)) + to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state + + with SimpleProfiler.profile("preprocessing"): + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + is_named_optimizer = _is_named_optimizer(optim_state_dict) + + param_key_to_param = cast( + dict[Union[int, str], nn.Parameter], + ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + + with SimpleProfiler.profile("preprocessing_with_comm"): + ( + all_optim_state_keys, + optim_state_key_to_param_key, + ) = _map_param_key_to_optim_keys( + optim_state_dict, + group, + param_key_to_param, + param_to_fqns, + fqn_to_fsdp_param_info, + merge_keys=use_orig_params, + ) + + with SimpleProfiler.profile("state_converting"): + convert_fn = ( + _convert_state_with_orig_params + if use_orig_params + else _convert_state_with_flat_params + ) + fsdp_osd_state = convert_fn( + all_optim_state_keys, + optim_state_key_to_param_key, + fqn_to_fsdp_param_info, + optim_state_dict["state"], + to_save, + shard_state, + cpu_offload, + ) + + # At this point, communication is complete and ranks can return early if nothing + # will be saved on that rank. + if not to_save: + return {} + + fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state} + + flat_param_fqns = set(flat_param_to_fqn.values()) + for key, value in optim_state_dict["state"].items(): + if key in fsdp_osd_state: + continue + if key in flat_param_fqns: + continue + if key in param_key_to_param: + continue + # This key is not recognized by FSDP. It may be a user-defined state + # or some parameters state that FSDP is unable to map from + # ``optim.param_groups``. + warnings.warn( + f"Found a optim state, {key}, that FSDP cannot process. FSDP " + "will directly copy everything to the returned state_dict. In " + "most cases, this is a user-defined state that is not " + "associated with any particular parameter. Another possible " + "case is this state is managed by TorchRec. Otherwise, there may " + " be a mismatched assumption of optim_state_dict of this mode.", + stacklevel=2, + ) + fsdp_osd_state[key] = value + + if "param_groups" in optim_state_dict: + fsdp_osd["param_groups"] = _unflatten_param_groups( + optim_state_dict, param_key_to_param, param_to_fqns + ) + + cm.close() + SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ") + + return fsdp_osd + + +def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]: + """ + Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo`` + if the param is managed by FSDP. Shared parameters, or original parameters that + are shared across multiple nn.Modules, are required to belong to one and only + one FSDP instance and thus correspond to one ``FlatParameter``. Within the one + ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared + parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters. + """ + + def module_fn(module, prefix, tree_level, fqn_to_param_info): + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state is None: + return + _lazy_init(fsdp_state, module) + handle = _module_handle(fsdp_state, module) + if not handle: + return + flat_param = handle.flat_param + fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, []) + # NOTE: `idx` indexes into the data structures *without* padding + # elements + for idx, local_fqn in enumerate(flat_param._fqns): + fqn = clean_tensor_name(prefix + local_fqn) + if fqn in fqn_to_param_info: + if fqn_to_param_info[fqn].handle.flat_param is not flat_param: + raise AssertionError( + f"Expected fqn_to_param_info[fqn].handle.flat_param is flat_param for {fqn}" + ) + fqn_to_param_info[fqn] = fsdp_param_info + fsdp_param_info.param_indices[fqn] = idx + if flat_param._params is not None: + fsdp_param_info.param_requires_grad.append( + flat_param._params[idx].requires_grad + ) + + def return_fn(fqn_to_param_info): + return fqn_to_param_info + + fqn_to_param_info: dict[str, FSDPParamInfo] = {} + # FlatParameter._fqns stores the local fqn, starting from the root of the + # FSDP. Using _apply_to_modules() with model (may not be the FSDP root + # module) allows us to construct the global fqn. + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + fqn_to_param_info, + ) + + +@no_type_check +def _set_optim_use_dtensor( + fsdp_state: _FSDPState, + state_dict_settings: StateDictSettings, +) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedOptimStateDictConfig() if state_dict_type + # has to be set to SHARDED_STATE_DICT. + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = state_dict_settings.state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT.", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + state_dict_settings.optim_state_dict_config._use_dtensor = True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eab47412f5d25a3c8a3472141208d6833ec633d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py @@ -0,0 +1,1654 @@ +# mypy: allow-untyped-defs +import functools +import logging +from collections.abc import Callable +from enum import auto, Enum +from typing import Any, no_type_check, Optional + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.autograd.graph import register_multi_grad_hook +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _assert_in_training_states, + _FSDPState, + _get_module_fsdp_state, + _is_composable, + _log_post_backward_hook, + _no_dispatch_record_stream, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, + HandleTrainingState, + RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES, +) +from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES +from torch.distributed.fsdp.api import BackwardPrefetch +from torch.distributed.utils import ( + _apply_to_tensors, + _cast_forward_inputs, + _p_assert, + _to_kwargs, +) +from torch.utils import _pytree as pytree + + +logger = logging.getLogger(__name__) + +# Do not include "process_group" to enable hybrid shard and MoE cases +HOMOGENEOUS_ATTR_NAMES = ( + "_use_orig_params", + "limit_all_gathers", + "_use_full_prec_in_eval", +) + + +class _PrefetchMode(Enum): + BACKWARD = auto() + FORWARD = auto() + + +def _get_fsdp_root_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the root ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the root modules owning the states in the first + list. + + This is similar to :func:`_get_fsdp_states_with_modules` except that we + must call :func:`_is_fsdp_root` to force a lazy initialization to determine + the FSDP root in case lazy initialization has not yet happened. + """ + fsdp_root_states: list[_FSDPState] = [] + fsdp_root_modules: list[nn.Module] = [] + visited_fsdp_states: set[_FSDPState] = set() + # NOTE: This function assumes that `module.modules()` proceeds top-down. + for submodule in module.modules(): + optional_state = _get_module_fsdp_state(submodule) + if ( + optional_state is not None + and optional_state not in visited_fsdp_states + and _is_fsdp_root(optional_state, submodule) + ): + visited_fsdp_states.add(optional_state) + fsdp_root_states.append(optional_state) + fsdp_root_modules.append(submodule) + return fsdp_root_states, fsdp_root_modules + + +def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_root_states_with_modules`.""" + fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module) + return fsdp_root_states + + +def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool: + """ + Returns if ``state`` corresponds to that of an FSDP root. + + For the wrapper code path, ``state`` and ``module`` should be the same. For + the non-wrapper code path, ``state`` should be ``module`` 's state. + """ + # Force a lazy initialization to determine the FSDP root + _lazy_init(state, module) + if state._is_root is None: + raise AssertionError("Expected _is_root to be set after lazy init") + return state._is_root + + +@no_type_check +def _lazy_init( + state: _FSDPState, + root_module: nn.Module, +) -> _FSDPState: + """ + Performs initialization lazily, typically right before the first forward + pass. The laziness is needed to ensure that the parameter device/dtype and + the FSDP hierarchy have finalized. This method's actual logic only runs on + the root FSDP instance, which performs initialization for all non-root FSDP + instances to avoid partial initialization. + + For the non-composable code path, ``state`` and ``root_module`` should be + the same, namely the FSDP instance itself. + """ + if state._is_root is not None: + return # no-op: already lazily initialized + if not state._device_handle.is_available(): + # Allow the FSDP constructor to run even without CUDA but check this + # once we start real execution + raise RuntimeError("FSDP does not support CPU only execution") + # The following logic is only run on the root FSDP instance since it will + # set `_is_root=False` for the non-root instances + state._is_root = True + _assert_in_training_states(state, [TrainingState.IDLE]) + _check_flat_params_on_expected_device(state, root_module) + state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module) + _init_streams(state) + buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module) + _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device) + state._exec_order_data.init(state, root_module, state.process_group) + _share_state_and_init_handle_attrs(state, root_module) + return state + + +def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module): + """ + Checks that all ``FlatParameter``s in ``module`` 's tree managed by + ``state`` are on the expected device for *lazy initialization*. + """ + cpu_device = torch.device("cpu") + for handle in traversal_utils._get_fsdp_handles(module): + if ( + not handle._offload_params + and handle.flat_param.device != state.compute_device + ): + raise RuntimeError( + "An FSDP-managed module unexpectedly has parameters on " + f"{handle.flat_param.device}. Make sure to move the module to " + f"{state.compute_device} before training." + ) + elif handle._offload_params and handle.flat_param.device != cpu_device: + raise RuntimeError( + "An FSDP-managed module with parameter CPU offloading enabled " + f"has parameters on {handle.flat_param.device}. Make sure to " + f"not move the module from CPU when offloading parameters." + ) + + +@no_type_check +def _share_state_and_init_handle_attrs( + root_state: _FSDPState, + root_module: nn.Module, +) -> None: + """ + Shares data structure state from the ``root_state`` to all FSDP states in + ``root_module`` 's module tree, and initializes handle attributes. These + are done together to require a single loop over the states. + """ + handle = root_state._handle + if handle: + handle.init_flat_param_attributes() + attr_name_to_values: dict[str, set[Any]] = {} + for attr_name in HOMOGENEOUS_ATTR_NAMES: + attr_name_to_values[attr_name] = set() + root_state._all_handles = root_state._exec_order_data.all_handles # share reference + # Update _has_optim_in_backward for each handle. + for handle in root_state._all_handles: + flat_param = handle.flat_param + if hasattr(flat_param, "_in_backward_optimizers"): + raise RuntimeError( + "FSDP optimizer in backward only supported with use_orig_params=True!" + ) + handle._has_optim_in_backward = flat_param._params is not None and any( + hasattr(param, "_in_backward_optimizers") for param in flat_param._params + ) + if handle._has_optim_in_backward: + torch._C._log_api_usage_once("fsdp.optimizer_in_backward") + for fsdp_state in root_state._all_fsdp_states: + for attr_name in HOMOGENEOUS_ATTR_NAMES: + _p_assert( + hasattr(fsdp_state, attr_name), + f"FSDP state missing attribute {attr_name}", + ) + attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name)) + if fsdp_state is root_state: + continue + # Relax the assert for non-root FSDP instances in case the nested + # initialized module is wrapped again in FSDP later (e.g. after + # training to run inference) + _p_assert( + fsdp_state._is_root is None or not fsdp_state._is_root, + "Non-root FSDP instance's `_is_root` should not have been " + "set yet or should have been set to `False`", + ) + fsdp_state._is_root = False + fsdp_state._unshard_stream = root_state._unshard_stream + fsdp_state._post_backward_stream = root_state._post_backward_stream + fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream + fsdp_state._all_reduce_stream = root_state._all_reduce_stream + fsdp_state._default_stream = root_state._default_stream + fsdp_state._exec_order_data = root_state._exec_order_data + fsdp_state._free_event_queue = root_state._free_event_queue + if fsdp_state._fsdp_extension is not None: + fsdp_state._fsdp_extension.compute_stream = root_state._default_stream + handle = fsdp_state._handle + if handle: + handle.init_flat_param_attributes() + for attr_name, attr_values in attr_name_to_values.items(): + if len(attr_values) != 1: + raise ValueError( + f"Expects one homogeneous value for {attr_name} but got {attr_values}" + ) + + +@no_type_check +def _init_streams( + state: _FSDPState, +) -> None: + """ + Initializes CUDA streams for overlapping communication, computation, and + data transfers. The streams should be shared across FSDP instances. + """ + if not state._is_root: + raise AssertionError("Expected state to be root") + if not state._device_handle.is_available(): + raise AssertionError("Expected device handle to be available") + uses_hybrid_sharding = any( + fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES + for fsdp_state in state._all_fsdp_states + ) + # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and + # preserve the default priority of 0 otherwise + high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0 + # Default stream for computation + state._default_stream = state._device_handle.current_stream() + if state._fsdp_extension is not None: + # set the compute stream to the FSDP extension + state._fsdp_extension.compute_stream = state._default_stream + + # Stream for unshard logic, including allocating the all-gather destination + # tensors and the all-gathers themselves + state._unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream for overlapping gradient reduction with the backward pass gradient + # computation + state._post_backward_stream = state._device_handle.Stream(priority=high_priority) + # Stream for pre-unshard logic, namely allocations and writes for CPU + # offloading (H2D copy) and mixed precision (low precision cast) + state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream to run HSDP's all-reduce as async (if using HSDP) + state._all_reduce_stream = ( + state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream + ) + + +@no_type_check +def _unshard( + state: _FSDPState, + handle: FlatParamHandle, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +) -> None: + """ + Unshards the handles in ``handles``. If the handles are in + :meth:`summon_full_params` and are using mixed precision, then they are + forced to full precision. + + Postcondition: handle's ``FlatParameter`` 's data is the padded + unsharded flat parameter on the compute device. + """ + if not handle: + return + with state._device_handle.stream(pre_unshard_stream): + ran_pre_unshard = handle.pre_unshard() + if ran_pre_unshard: + unshard_stream.wait_stream(pre_unshard_stream) + if state.limit_all_gathers: + event = state._free_event_queue.dequeue_if_needed() + if event: + with torch.profiler.record_function( + "FullyShardedDataParallel.rate_limiter" + ): + event.synchronize() + with state._device_handle.stream(unshard_stream): + handle.unshard() + handle.post_unshard() + + +@no_type_check +def _reshard( + state: _FSDPState, + handle: FlatParamHandle, + free_unsharded_flat_param: bool, +): + """ + Reshards the handle. ``free_unsharded_flat_param`` indicates whether to + free the handle's padded unsharded flat parameter. + """ + handle.reshard(free_unsharded_flat_param) + if state.limit_all_gathers and free_unsharded_flat_param: + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # We don't run a even queue for freeing under torch compile atm + # But maybe we need to? TODO(voz): Look into this + free_event = state._device_handle.Event() + free_event.record() + state._free_event_queue.enqueue(free_event) + handle.post_reshard() + # Flat parameter freed or not, we always have to "unshard" the parameter + # upon next access to get its shape correct. + handle._prefetched = False + + +def _unshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.unshard_grad() + + +def _reshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.reshard_grad() + + +@no_type_check +def _pre_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + unshard_fn: Callable, + module: nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """ + Runs the pre-forward logic. This includes an opportunity to unshard + currently sharded parameters such as those for the current forward and + registering post-backward hooks for these current parameters. This function + also converts forward ``args`` and ``kwargs`` to the given precision. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + unshard_fn (Optional[Callable]): A callable to unshard any currently + sharded parameters or ``None`` to not do any unsharding. + module (nn.Module): Module whose forward this method runs right before; + expected by the hook signature. + args (Tuple[Any, ...]): Module forward ``args``. + kwargs (Dict[str, Any]): Module forward ``kwargs``. + """ + with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"): + # For `fully_shard` + `checkpoint`, skip pre-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + # For both checkpoint implementations, we do not need to re-cast + # inputs here since they will be checkpointed in the low precision + # either by AC or normally by autograd as long as the AC region is + # nested within FSDP + return args, kwargs + state.training_state = TrainingState.FORWARD_BACKWARD + state._exec_order_data.record_pre_forward(handle, module.training) + if handle: + handle._training_state = HandleTrainingState.FORWARD + if unshard_fn is not None: + unshard_fn(state, handle) + # Register post-backward hooks to reshard the parameters and reduce-scatter + # their gradients. They must be re-registered every forward pass in case + # the `grad_fn` is mutated. + _register_post_backward_hook(state, handle) + # We have to reallocate the _cpu_grad if optimizer overlap + # set the grad to None in the backward pass. + if handle and handle._offload_params and handle.flat_param._cpu_grad is None: + handle.flat_param._cpu_grad = torch.zeros_like( + handle.flat_param._local_shard, device=torch.device("cpu") + ).pin_memory() + + should_cast_forward_inputs = ( + state._handle and not state._handle._force_full_precision + ) + + if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs: + # Recursively convert args and kwargs to specified precision. + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + _register_post_backward_reshard_only_hook(state, handle, args, kwargs) + return args, kwargs + + +@no_type_check +def _pre_forward_unshard( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """Unshards parameters in the pre-forward.""" + if not handle: + return + # If the handles have been prefetched, then there is no need to call + # `_unshard()` again + if not handle._prefetched: + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._needs_pre_forward_unshard = False + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + current_stream = state._device_handle.current_stream() + if state._unshard_event is not None: + current_stream.wait_event(state._unshard_event) + state._unshard_event = None + else: + current_stream.wait_stream(state._unshard_stream) + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_forward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.FORWARD) + + +@no_type_check +def _post_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + reshard_fn: Callable, + module: nn.Module, + input: Any, + output: Any, +) -> Any: + """ + Runs the post-forward logic. This includes an opportunity to reshard + currently unsharded parameters such as those used in the current forward + and registering pre-backward hooks on the forward outputs. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + reshard_fn (Optional[Callable]): A callable to reshard any currently + unsharded parameters (e.g. from the current forward) or ``None`` to + not do any resharding. + module (nn.Module): Module whose forward just ran, which should be a + fully sharded module (see [Note: Fully Sharded Module]); expected + by the hook signature. + input (Any): Unused; expected by the hook signature. + output (Any): Forward pass output; pre-backward hooks are registered on + the tensors that require gradients in this output. + + Postcondition: Each ``FlatParameter`` 's data points to the sharded flat + parameter. + """ + with torch.profiler.record_function("FullyShardedDataParallel._post_forward"): + # For `fully_shard` + `checkpoint`, skip post-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + return output + + state._exec_order_data.record_post_forward(handle) + if reshard_fn is not None: + reshard_fn(state, handle) + # Register pre-backward hooks to unshard the flat parameters for the + # gradient computation (if needed) + output = _register_pre_backward_hooks(state, module, output, handle) + state.training_state = TrainingState.IDLE + if handle: + handle._training_state = HandleTrainingState.IDLE + return output + + +@no_type_check +def _post_forward_reshard( + state: _FSDPState, + handle: FlatParamHandle, +) -> None: + """Reshards parameters in the post-forward.""" + if not handle: + return + # Do not free the root's parameters in the post-forward for `FULL_SHARD` + # with the intention that they are immediately used for backward + # computation (though this may not be true) + free_unsharded_flat_param = ( + not state._is_root + and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + _reshard(state, handle, free_unsharded_flat_param) + + +@no_type_check +def _root_pre_forward( + state: _FSDPState, + module: nn.Module, + args, + kwargs, +) -> None: + """ + Runs pre-forward logic specific to the root FSDP instance, which should run + before any individual module's pre-forward. This starts with an attempt at + lazy initialization (which only runs non-vacuously once). Otherwise, if + this is called on a non-root FSDP instance, then it returns directly. + + Args: + module (nn.Module): Module for which this logic tries to run. It may or + may not be the root. If not, then this method does not do anything. + """ + with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"): + _lazy_init(state, module) + _p_assert(state._is_root is not None, "Expects a root FSDP to have been set") + if not state._is_root: + # Always cast forward inputs in the root of this local FSDP unit for mixed + # precision, as this is where mixed precision could be configured. + # This is more useful for auto wrapping that is recommended in composable path. + # For manual wrapping, cast forward inputs on each local FSDP unit root will + # increase some overhead, so not turned on for model wrapper path right now where + # manual wrapping is more broadly used. + if _is_composable(state): + return _root_cast_forward_input(state, module, args, kwargs) + return args, kwargs + + # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers + # are in full precision and if we should cast them back to lower precision, which happens when + # exiting eval() mode. + handle = state._handle + if handle: + should_cast_buffers_to_full_prec = handle._force_full_precision + else: + # If the root has no handle (no managed parameters), then we fall + # back to checking if any child wants to force full precision as a + # workaround + handles = traversal_utils._get_fsdp_handles(module) + should_cast_buffers_to_full_prec = any( + handle._force_full_precision for handle in handles + ) + + if should_cast_buffers_to_full_prec: + _cast_buffers_to_dtype_and_device( + buffers=dict(module.named_buffers()).values(), + buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()), + device=state.compute_device, + ) + # This flag is only set when we cast buffers to full precision, to avoid the + # CPU overhead that can stem from retrieving all buffers and their types in the + # following else branch. + state._needs_buffer_dtype_restore_check = True + elif getattr(state, "_needs_buffer_dtype_restore_check", False): + # Check if buffers are in full precision and we need to cast them + # back down. + ( + buffers, + buffer_dtypes_for_computation, + ) = _get_buffers_and_dtypes_for_computation(state, module) + if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0: + if any( + buffer.dtype != buffer_dtype_for_computation + for buffer, buffer_dtype_for_computation in zip( + buffers, buffer_dtypes_for_computation + ) + ): + # Assume we have to cast everything if there is one mismatch + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes_for_computation, state.compute_device + ) + # We don't have to check this again until we cast buffers to full precision again. + state._needs_buffer_dtype_restore_check = False + + if state.forward_prefetch: + handles = [ + fsdp_state._handle + for fsdp_state in state._all_fsdp_states + if fsdp_state._handle + ] + for handle in handles: + handle._needs_pre_forward_unshard = True + handle._prefetched = False + _wait_for_computation_stream( + state._device_handle.current_stream(), + state._unshard_stream, + state._pre_unshard_stream, + ) + _reset_flat_param_grad_info_if_needed(state._all_handles) + + # Prepares the forward inputs by moving them to ``compute_device`` + # TODO: Do not use the side stream for tensor copies for now; investigate + # the perf with/without it. + with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, state.compute_device, False + ) + args = args_tuple[0] if args_tuple else tuple() + kwargs = kwargs_tuple[0] if kwargs_tuple else {} + + return _root_cast_forward_input(state, module, args, kwargs) + + +@no_type_check +def _root_cast_forward_input( + state: _FSDPState, module: torch.nn.Module, args, kwargs +) -> tuple[Any, Any]: + if state._handle: + force_full_precision = not state._handle._force_full_precision + else: + force_full_precision = True + + should_cast_forward_inputs = ( + (module.training or not state._use_full_prec_in_eval) and force_full_precision + ) and state.mixed_precision.cast_root_forward_inputs + + if should_cast_forward_inputs: + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + + return args, kwargs + + +@no_type_check +def _pre_backward_hook( + state: _FSDPState, + module: nn.Module, + handle: FlatParamHandle, + grad, + *unused: Any, +) -> Any: + """ + Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + """ + # Only run the pre-backward hook once per group of handles involved in the + # same module forward computation + if ( + handle + and hasattr(handle, "_ran_pre_backward_hook") + and handle._ran_pre_backward_hook + ): + return grad + + with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"): + # Queue the post-backward callback once for the root FSDP instance to + # attach it to the outermost backward graph task so that it is called + # after all backward calls complete + if state._is_root and not state._post_backward_callback_queued: + _register_post_backward_final_callback(state, module) + _reset_flat_param_grad_info_if_needed(state._all_handles) + elif handle: + allowed_states = [TrainingState.IDLE] + if _is_composable(state): + allowed_states.append(TrainingState.FORWARD_BACKWARD) + _assert_in_training_states(state, allowed_states) + state.training_state = TrainingState.FORWARD_BACKWARD + # Queueing the post-backward callback is the only logic that is not + # per-handle in the pre-backward hook, so we can return early here if + # there are no handles. + if not handle: + return grad + handle._training_state = HandleTrainingState.BACKWARD_PRE + + if handle._needs_pre_backward_unshard: + # If the handles have been prefetched, then there is no need to + # call `_unshard()` again + if not handle._prefetched: + _unshard( + state, + handle, + state._unshard_stream, + state._pre_unshard_stream, + ) + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._device_handle.current_stream().wait_stream(state._unshard_stream) + + # Set this to `False` to ensure that a mistargeted prefetch does not + # actually unshard these handles + handle._needs_pre_backward_unshard = False + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + handle.prepare_gradient_for_backward() + handle._ran_pre_backward_hook = True + return grad + + +@no_type_check +@torch.no_grad() +def _post_backward_hook( + state: _FSDPState, + handle: FlatParamHandle, + flat_param, + *unused: Any, +): + """ + Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``. + + Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the + unsharded gradient for the local batch. + + Postcondition: + - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced + unsharded gradient. + - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded + gradient (accumulating with any existing gradient). + """ + _log_post_backward_hook(state, handle, logger) + flat_param = handle.flat_param + flat_param._post_backward_called = True + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook" + ): + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + # For multiple applications of reentrant AC across submodules sharing + # the same `FlatParameter`, the post-backward hook may run multiple + # times in one backward, in which case we permit the state to already + # be in `BACKWARD_POST`. + _p_assert( + handle._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST), + f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}", + ) + handle._training_state = HandleTrainingState.BACKWARD_POST + + if flat_param.grad is None: + return + if flat_param.grad.requires_grad: + raise RuntimeError("FSDP does not support gradients of gradients") + + _post_backward_reshard(state, handle) + if not state._sync_gradients: + if handle._use_orig_params: + handle._use_unsharded_grad_views() + return + + # Wait for all ops in the current stream (e.g. gradient computation) to + # finish before reduce-scattering the gradient + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_stream.wait_stream( + state._device_handle.current_stream() + ) + + with state._device_handle.stream(state._post_backward_stream): + autograd_computed_grad = flat_param.grad.data + if ( + not _low_precision_hook_enabled(state) + and flat_param.grad.dtype != handle._reduce_dtype + # If we are forcing full precision but communicating grads + # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient. + and not handle._force_full_precision + ): + flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype) + if handle.uses_sharded_strategy: + _reduce_grad(state, handle) + else: + _reduce_grad_no_shard(state, handle) + # Since the unsharded gradient is produced in the computation + # stream and consumed in the post-backward stream, inform the + # caching allocator (before it goes out of scope) + _no_dispatch_record_stream( + autograd_computed_grad, state._post_backward_stream + ) + + +def _post_backward_reshard_only_hook( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook_reshard_only" + ): + # `_pre_backward_hook` may not get executed + # if forward output does not require grad + # overwrite IDLE state for post-backward prefetching + state.training_state = TrainingState.FORWARD_BACKWARD + handle._training_state = HandleTrainingState.BACKWARD_POST + _post_backward_reshard(state, handle) + + +def _post_backward_reshard( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + free_unsharded_flat_param = _should_free_in_backward(state, handle) + _reshard(state, handle, free_unsharded_flat_param) + + # TODO: Post-backward prefetching does not support the multiple handles + # per module case since the post-backward hook runs per handle, not per + # group of handles. + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + + +@no_type_check +def _should_free_in_backward( + state: _FSDPState, + handle: FlatParamHandle, +) -> bool: + """ + Returns whether FSDP should free the unsharded flat parameter in the + post-backward or not. + """ + if not handle.uses_sharded_strategy: + return False + # If not syncing gradients, then we do not free for strategies that do not + # reshard after forward as a *heuristic* to tradeoff higher memory for + # higher throughput. + return ( + state._sync_gradients + or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + + +@no_type_check +def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For sharded strategies, this runs gradient reduction, sharded gradient + accumulation if needed, and the post-reduction callback. + """ + flat_param = handle.flat_param + uses_hybrid_sharded_strategy = handle._sharding_strategy in ( + HandleShardingStrategy.HYBRID_SHARD, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, + ) + # We clear `.grad` to permit multiple backwards. This avoids a race where + # the second backward pass computation precedes ahead of the first backward + # pass reduction, which is possible since the reduction is issued in a + # separate stream and is async and would result in reducing the wrong + # gradient. + unsharded_grad = flat_param.grad.data + flat_param.grad = None + padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors( + state, unsharded_grad + ) + if state._comm_hook is None: # default path + _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor) + pg = ( + handle._fake_process_group + if handle._use_fake_reduce + else state.process_group + ) + dist.reduce_scatter_tensor( + new_sharded_grad, + padded_unsharded_grad, + group=pg, + ) + if uses_hybrid_sharded_strategy: + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._all_reduce_stream.wait_stream(state._post_backward_stream) + with state._device_handle.stream(state._all_reduce_stream): + # Since the new sharded gradient is produced in the post- + # backward stream and consumed in the all-reduce stream, + # inform the caching allocator + _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream) + dist.all_reduce(new_sharded_grad, group=state._inter_node_pg) + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + grad_to_offload = _accumulate_sharded_grad( + state, handle, new_sharded_grad + ) + _post_reduce_grad_callback(state, handle, grad_to_offload) + return + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + else: + state._comm_hook( + state._comm_hook_state, padded_unsharded_grad, new_sharded_grad + ) + # NOTE: HSDP variants do not support communication hook. + grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad) + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _get_reduce_scatter_tensors( + state: _FSDPState, unsharded_grad: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns the input and output tensors to reduce-scatter, respectively. + """ + chunks = list(unsharded_grad.chunk(state.world_size)) + numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel() + padded_unsharded_grad = ( + F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad + ) + new_sharded_grad = torch.empty_like(chunks[0]) # padded + return padded_unsharded_grad, new_sharded_grad + + +@no_type_check +def _accumulate_sharded_grad( + state: _FSDPState, + handle: FlatParamHandle, + sharded_grad: torch.Tensor, +) -> torch.Tensor: + """ + Accumulates the reduce-scattered sharded gradient with any existing sharded + gradient if needed, returning the gradient to offload (if CPU offloading is + enabled). + """ + flat_param = handle.flat_param + _cast_grad_to_param_dtype(state, sharded_grad, flat_param) + # Save the sharded gradient in `_saved_grad_shard` to support gradient + # accumulation -- for multiple backwards, the gradient reductions may + # happen in arbitrary order + accumulate_grad = hasattr(flat_param, "_saved_grad_shard") + if accumulate_grad: + _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard) + flat_param._saved_grad_shard += sharded_grad + else: + flat_param._saved_grad_shard = sharded_grad + grad_to_offload = flat_param._saved_grad_shard + return grad_to_offload + + +@no_type_check +def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For no-shard, this runs gradient reduction (which directly covers any + gradient accumulation implicitly) and the post-reduction callback. + """ + flat_param = handle.flat_param + if state._comm_hook is None: # default path + _div_if_needed(flat_param.grad, state._gradient_predivide_factor) + dist.all_reduce(flat_param.grad, group=state.process_group) + _div_if_needed(flat_param.grad, state._gradient_postdivide_factor) + else: + state._comm_hook(state._comm_hook_state, flat_param.grad) + # For `NO_SHARD`, we can keep the low precision gradients by simply + # omitting the cast altogether + if not handle._keep_low_precision_grads: + _cast_grad_to_param_dtype(state, flat_param.grad, flat_param) + grad_to_offload = flat_param.grad.data + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _post_reduce_grad_callback( + state: _FSDPState, + handle: FlatParamHandle, + # Additional arguments needed for the callback logic + grad_to_offload: torch.Tensor, +): + """ + This callback captures any logic to run after the gradient reduction + finishes. Currently, this offloads the gradient to CPU if CPU offloading is + enabled and uses sharded gradient views if ``use_orig_params=True``. + """ + _offload_grad(state, handle, grad_to_offload) + _post_backward_use_sharded_grad_views(handle) + + +@no_type_check +def _offload_grad( + state: _FSDPState, + handle: FlatParamHandle, + grad_to_offload: torch.Tensor, +): + if not handle._offload_params: + return + # Offload the gradient to CPU to ensure parameters and gradients are on the + # same device as required by the optimizer + # TODO: Investigate why `NO_SHARD` breaks correctness when using + # `non_blocking=True` here. + # TODO (rohan-varma): When CPU offload and optimizer overlap, + # non_blocking=True won't work since the copy may have not finished before + # the optimizer step executes on CPU. If we want to use non-blocking=True + # here, we'll have to synchronize before using result on CPU. + non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward + handle.flat_param._cpu_grad.copy_( + grad_to_offload.detach(), non_blocking=non_blocking + ) # synchronized in the post-backward callback + # Since the gradient being offloaded may have been produced in the + # computation stream and is being consumed here in the post-backward + # stream, inform the caching allocator + _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream) + + +@no_type_check +def _post_backward_use_sharded_grad_views(handle: FlatParamHandle): + if not handle._use_orig_params: + return + # Since the handle's `FlatParameter` completed its gradient computation, we + # should reset the gradient noneness mask + handle._reset_is_grad_none() + # Delay using sharded gradient views until after the reduce-scatter instead + # of immediately after resharding + handle._use_sharded_grad_views() + if handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + for orig_param in handle.flat_param._params: + # Check for `None` gradient to filter parameters not in the rank + if orig_param.grad is not None and hasattr( + orig_param, "_in_backward_optimizers" + ): + # TODO (rohan-varma): For CPU offload, this unfortunately + # operates on CPU because the parameters and gradients have + # already been offloaded. We should run this on GPU after + # refactoring. + for optim in orig_param._in_backward_optimizers: + optim.step() + + optim.zero_grad(set_to_none=True) + handle._reset_flat_param_grad_info_if_needed() + if handle._offload_params: + handle.flat_param._cpu_grad = None + + +def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None: + if div_factor > 1: + tensor.div_(div_factor) + + +@no_type_check +def _cast_grad_to_param_dtype( + state: _FSDPState, + sharded_grad: torch.Tensor, + param: FlatParameter, +): + """ + Casts ``sharded_grad`` back to the full parameter dtype so that the + optimizer step runs with that dtype. This performs an actual cast if + 1. parameters were in reduced precision during the forward since then + gradients would be in that reduced precision, or + 2. parameters were not in reduced precision but gradients were in + reduced precision for communication. + However, if a low precision communication hook is registered, then this + dtype cast happens in the hook instead. + """ + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype: + low_prec_grad_data = sharded_grad.data + sharded_grad.data = sharded_grad.data.to(dtype=param.dtype) + # Since for `NO_SHARD`, the gradient is produced in the computation + # stream and consumed here in the post-backward stream, inform the + # caching allocator; for the sharded strategies, the gradient is + # produced in the post-backward stream, so this `record_stream()` + # should be a no-op + _no_dispatch_record_stream( + low_prec_grad_data, state._device_handle.current_stream() + ) + + +def _check_grad_to_accumulate( + new_sharded_grad: torch.Tensor, + accumulated_grad: torch.Tensor, +) -> None: + _p_assert( + accumulated_grad.shape == new_sharded_grad.shape, + "Shape mismatch when accumulating gradients: " + f"existing gradient shape={accumulated_grad.shape} " + f"new gradient shape={new_sharded_grad.shape}", + ) + _p_assert( + accumulated_grad.device == new_sharded_grad.device, + "Device mismatch when accumulating gradients: " + f"existing gradient device={accumulated_grad.device} " + f"new gradient device={new_sharded_grad.device}", + ) + + +@no_type_check +def _low_precision_hook_enabled(state: _FSDPState) -> bool: + return state._comm_hook in LOW_PRECISION_HOOKS + + +@no_type_check +@torch.no_grad() +def _post_backward_final_callback( + state: _FSDPState, + module: nn.Module, +): + """ + This waits for the post-backward to finish and performs some final cleanup. + This runs at the end of the entire backward pass and should only be called + on the root FSDP instance. + """ + _p_assert( + state._is_root, + "The post-backward callback should only be called on the root FSDP instance", + ) + root_state = state + + if root_state._sync_gradients: + current_stream = state._device_handle.current_stream() + # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish + # since it currently runs in the post-backward stream. That can be + # pushed to the next forward if run in a different stream + current_stream.wait_stream(root_state._post_backward_stream) + if root_state._all_reduce_stream is not current_stream: # uses HSDP + current_stream.wait_stream(root_state._all_reduce_stream) + if root_state.cpu_offload.offload_params: + # Wait for non-blocking GPU -> CPU sharded gradient copies from the + # post-backward hooks to finish explicitly since CPU gradients do + # not automatically synchronize with the GPU + state._device_handle.current_stream().synchronize() + root_state._exec_order_data.next_iter() + + for fsdp_state in state._all_fsdp_states: + _catch_all_reshard(fsdp_state) + _finalize_params(fsdp_state) + fsdp_state.training_state = TrainingState.IDLE + handle = fsdp_state._handle + if handle: + handle._ran_pre_backward_hook = False + handle._needs_pre_backward_unshard = False + handle._post_forward_index = None + handle._training_state = HandleTrainingState.IDLE + handle._prefetched = False + # Reset for cases like one forward and multiple backwards + root_state._post_backward_callback_queued = False + + +@no_type_check +def _catch_all_reshard( + state: _FSDPState, +) -> None: + """ + Reshards the parameters that may not have been resharded in the + post-backward hook. This can happen when a module's output is used in the + forward pass, meaning that its pre-backward hook runs (unsharding the + parameter), but the post-backward hook does not run because the output was + not jused in the loss computation corresponding to this backward pass. + """ + # Wrap with a try-except to provide a more informative traceback if an + # error is raised + try: + if state._handle: + # TODO: This already-resharded check is brittle: + # https://github.com/pytorch/pytorch/issues/83956 + already_resharded = ( + state._handle.flat_param.data_ptr() + == state._handle.flat_param._local_shard.data_ptr() + # If FSDP skipped using sharded views, then the flat parameter + # still points to the sharded data, so we need to reshard to + # use sharded views + and not state._handle._skipped_use_sharded_views + ) + if already_resharded: + return + free_unsharded_flat_param = _should_free_in_backward(state, state._handle) + _reshard(state, state._handle, free_unsharded_flat_param) + except Exception as e: + _p_assert( + False, + f"Got exception in the catch-all reshard for {state}: {str(e)}", + raise_assertion_error=False, + ) + raise e + + +@no_type_check +def _finalize_params( + state: _FSDPState, +) -> None: + """Finalizes the parameters before the next iteration.""" + handle = state._handle + if not handle: + return + flat_param = handle.flat_param + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + if hasattr(flat_param, "_post_backward_hook_handle"): + pbhs_handle = flat_param._post_backward_hook_handle + pbhs_handle.remove() + del flat_param._post_backward_hook_handle + else: + if hasattr(flat_param, "_post_backward_hook_state"): + post_backward_hook_state_len = len(flat_param._post_backward_hook_state) + expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1 + _p_assert( + post_backward_hook_state_len == expected_post_backward_hook_state_len, + f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}", + ) + flat_param._post_backward_hook_state[-1].remove() + delattr(flat_param, "_post_backward_hook_state") + if flat_param.requires_grad: + if not state._sync_gradients: + # Preserve the gradient accumulation state if not synchronizing + # gradients: `.grad` remains the unsharded gradient from prior + # `no_sync()` iterations, and `_saved_grad_shard` remains the + # sharded gradient from the last synchronized iteration + return + if not handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + _p_assert( + hasattr(flat_param, "_post_backward_called"), + "Expects `_post_backward_called` to be set on the `FlatParameter`", + ) + flat_param._post_backward_called = False + + +@no_type_check +def _prefetch_handle( + state: _FSDPState, + current_handle: Optional[FlatParamHandle], + prefetch_mode: _PrefetchMode, +) -> None: + """ + Prefetches the next handles if needed (without synchronization). An empty + handles key cannot prefetch. + """ + if not current_handle: + return + handle = _get_handle_to_prefetch(state, current_handle) + if not handle: + return + # Temporarily emulate the training state while calling `_unshard` to + # ensure the correct `as_params` for `_use_unsharded_views()` + prev_training_state = handle._training_state + if prefetch_mode == _PrefetchMode.BACKWARD: + handle._training_state = HandleTrainingState.BACKWARD_PRE + elif prefetch_mode == _PrefetchMode.FORWARD: + handle._training_state = HandleTrainingState.FORWARD + else: + raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}") + # Prefetch the next set of handles without synchronizing to allow + # the sync to happen as late as possible to maximize overlap + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._training_state = prev_training_state + handle._prefetched = True + + +@no_type_check +def _get_handle_to_prefetch( + state: _FSDPState, + current_handle: FlatParamHandle, +) -> FlatParamHandle: + """ + Returns a :class:`list` of the handles keys to prefetch for the next + module(s), where ``current_handle`` represents the current module. + + "Prefetching" refers to running the unshard logic early (without + synchronization), and the "next" modules depend on the recorded execution + order and the current training state. + """ + training_state = _get_training_state(current_handle) + valid_training_states = ( + HandleTrainingState.BACKWARD_PRE, + HandleTrainingState.BACKWARD_POST, + HandleTrainingState.FORWARD, + ) + _p_assert( + training_state in valid_training_states, + f"Prefetching is only supported in {valid_training_states} but " + f"currently in {training_state}", + ) + eod = state._exec_order_data + target_handle: Optional[FlatParamHandle] = None + if ( + training_state == HandleTrainingState.BACKWARD_PRE + and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE + ) or ( + training_state == HandleTrainingState.BACKWARD_POST + and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST + ): + target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_backward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch: + target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_forward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + + return target_handle + + +def _get_training_state( + handle: FlatParamHandle, +) -> HandleTrainingState: + """Returns the training state of the handles in ``handle``.""" + _p_assert(handle, "Expects a non-empty handle") + return handle._training_state + + +@no_type_check +def _register_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a pre-forward hook on ``module``. + """ + for forward_handle in state._pre_forward_handles: + forward_handle.remove() + state._pre_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _pre_forward, state, module_param_handle, _pre_forward_unshard + ) + state._pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_post_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a post-forward hook on ``module``. Even if the module has no + handles, we should register the hook since it will register the module's + pre-backward hook. + """ + for forward_handle in state._post_forward_handles: + forward_handle.remove() + state._post_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _post_forward, + state, + module_param_handle, + _post_forward_reshard, + ) + state._post_forward_handles.append(module.register_forward_hook(hook)) + + +@no_type_check +def _register_root_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +): + """ + Registers root pre-forward hook on ``module``, which should be the local + FSDP root. + + NOTE: For the current composable FSDP design, we have each application of + ``fully_shard()`` to a module to indicate that that module is the local + FSDP root. We may remove this assumption in the future, in which case we + will need to register this root pre-forward hook on any candidate module + that may be the local FSDP root. + """ + for forward_handle in state._root_pre_forward_handles: + forward_handle.remove() + state._root_pre_forward_handles.clear() + hook = functools.partial(_root_pre_forward, state) + state._root_pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_pre_backward_hooks( + state: _FSDPState, + module: nn.Module, + outputs: Any, + handle: FlatParamHandle, +) -> None: + """ + Registers pre-backward hooks on the tensors that require gradients in the + forward pass outputs ``outputs``, which were computed using the + ``FlatParameter`` s of ``handles``. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + + Returns: + Forward pass outputs with pre-backward hooks registered to tensors that + require gradients. + """ + # If there is no gradient computation, then there is no need for + # pre-backward logic + if not torch.is_grad_enabled(): + return outputs + if state._is_root: + state._post_backward_callback_queued = False # only defined on the root + + if handle: + handle._needs_pre_backward_unshard = False + # Since these handles' `FlatParameter`s participated in a forward, we + # conservatively assume that they will be used in the backward + handle._ran_pre_backward_hook = False + + def _register_hook(t: torch.Tensor) -> torch.Tensor: + if t.requires_grad: + t.register_hook( + torch.utils.hooks.unserializable_hook( + functools.partial(_pre_backward_hook, state, module, handle) + ) + ) + if handle: + handle._needs_pre_backward_unshard = True + return t + + return _apply_to_tensors(_register_hook, outputs) + + +def _register_post_backward_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """ + Registers post-backward hooks on the ``FlatParameter`` s' + ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients. + + The ``AccumulateGrad`` object represents the last function that finalizes + the ``FlatParameter`` 's gradient, so it only runs after its entire + gradient computation has finished. + + We register the post-backward hook only once in the *first* forward that a + ``FlatParameter`` participates in. This relies on the ``AccumulateGrad`` + object being preserved through multiple forwards. + + NOTE: We follow this heuristic to prefer the *first* forward to target the + parameter mixed precision case, where there are *separate* + ``AccumulateGrad`` objects across the different forwards. (Without + parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If + we instead prefer the *last* forward, then the hook runs early. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + if already_registered or not flat_param.requires_grad: + return + hook = functools.partial(_post_backward_hook, state, handle) + hook_handle = flat_param.register_post_accumulate_grad_hook(hook) + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined] + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + if already_registered or not flat_param.requires_grad: + return + # Get the `AccumulateGrad` object + temp_flat_param = flat_param.expand_as(flat_param) + _p_assert( + temp_flat_param.grad_fn is not None, + "The `grad_fn` is needed to access the `AccumulateGrad` and " + "register the post-backward hook", + ) + acc_grad = temp_flat_param.grad_fn.next_functions[0][0] # type: ignore[union-attr] + if acc_grad is None: + raise AssertionError("Expected acc_grad to be set") + hook_handle = acc_grad.register_hook( + functools.partial(_post_backward_hook, state, handle) + ) + flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined] + + +def _register_post_backward_reshard_only_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> None: + """ + Registers post-backward hooks to reshard flat parameters that do not + require gradient. We register these using multi-post-grad hooks on the + input activations to ensure that all gradients that may depend on the + parameters have been computed before resharding. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + # Construct `inp_tensors` lazily to avoid CPU overhead in typical case + # where each flat parameter requires gradient + inp_tensors: Optional[list[torch.Tensor]] = None + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + + if already_registered or flat_param.requires_grad: + return + if inp_tensors is None: + args_flat = pytree.arg_tree_leaves(*args, **kwargs) + inp_tensors = [ + obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad + ] + if inp_tensors is None: + raise AssertionError("Expected inp_tensors to be set") + hook_handle = register_multi_grad_hook( + inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle) + ) + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined, assignment] + else: + flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment] + + +@no_type_check +def _register_post_backward_final_callback( + state: _FSDPState, module: nn.Module +) -> None: + """ + Registers the post-backward final callback that runs at the end of the + backward pass. This should be called from the root FSDP instance at the + beginning of the pre-backward. + """ + _p_assert( + state._is_root, + "Only the root FSDP instance should register the post-backward callback", + ) + if state._post_backward_callback_queued: + return + _assert_in_training_states(state, [TrainingState.IDLE]) + # Trace does not need this callback + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_callback_queued = True + Variable._execution_engine.queue_callback( + functools.partial(_post_backward_final_callback, state, module) + ) + + +def _wait_for_computation_stream( + computation_stream: torch.Stream, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +): + """ + Has the unshard and pre-unshard streams wait for the computation stream. + For example, this should be called in the FSDP root's pre-forward to + respect optimizer step computation. + """ + # Tracing does not need to wait + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + # Having the pre-all-gather stream wait for the current stream even if we + # do not leverage the pre-all-gather stream is tolerable since this only + # runs once per iteration + pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + + +def _reset_flat_param_grad_info_if_needed( + handles: list[FlatParamHandle], +): + """ + Clears the original parameters' gradients if needed. This method's CPU + overhead is minimal, so we may call it throughout FSDP methods, which serve + as callsites to free the gradient memory earlier. + """ + if not isinstance(handles, list): + handles = [handles] + for handle in handles: + if handle._use_orig_params: + handle._reset_flat_param_grad_info_if_needed() + + +@no_type_check +def _get_buffers_and_dtypes_for_computation( + state: _FSDPState, + root_module: nn.Module, +) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]: + """ + Returns all buffers in the module tree rooted at ``root_module`` and a + corresponding list of the buffer dtypes for computation. Each buffer dtype + is either ``None`` if buffer mixed precision is not enabled or the buffer + low precision dtype otherwise. + """ + _p_assert(state._is_root, "Expects the root to cast buffers") + buffers: list[torch.Tensor] = [] + buffer_dtypes: list[Optional[torch.dtype]] = [] + visited_buffers: set[torch.Tensor] = set() + # Traverse the FSDP states bottom-up so that we prefer the owning FSDP + # instance's mixed precision setting for each buffer + fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules( + root_module + ) + for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)): + for buffer_name, buffer in fsdp_module.named_buffers(): + if buffer in visited_buffers: + continue + visited_buffers.add(buffer) + if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names: + continue + buffers.append(buffer) + buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype) + if len(buffers) != len(buffer_dtypes): + raise AssertionError( + f"Expected buffers and buffer_dtypes to have the same length, got {len(buffers)} and {len(buffer_dtypes)}" + ) + return buffers, buffer_dtypes + + +@no_type_check +def _get_orig_buffer_dtypes( + state: _FSDPState, + buffer_names: list[str], +) -> list[torch.dtype]: + """ + Returns the original buffer types of the given buffer names. + """ + buffer_dtypes: list[torch.dtype] = [] + for buffer_name in buffer_names: + _p_assert( + buffer_name in state._buffer_name_to_orig_dtype, + f"{buffer_name} is missing from pre-computed dict on rank " + f"{state.rank}, which only has keys " + f"{state._buffer_name_to_orig_dtype.keys()}", + ) + buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name]) + return buffer_dtypes + + +def _cast_buffers_to_dtype_and_device( + buffers: list[torch.Tensor], + buffer_dtypes: list[Optional[torch.dtype]], + device: torch.device, +) -> None: + """ + Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them + to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the + corresponding buffer is only moved to ``device``. + """ + _p_assert( + buffer_dtypes is None or len(buffers) == len(buffer_dtypes), + f"Expects `buffers` and `buffer_dtypes` to have the same length if " + f"`buffer_dtypes` is specified but got {len(buffers)} and " + f"{len(buffer_dtypes)}", + ) + for buffer, buffer_dtype in zip(buffers, buffer_dtypes): + if not torch.is_floating_point(buffer) or buffer_dtype is None: + buffer.data = buffer.to(device=device) + else: + buffer.data = buffer.to(device=device, dtype=buffer_dtype) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eca5b9bd398749f1f38f50a48969cfbc3758352a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_shard_utils.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import math +from typing import Optional + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard + + +def _get_remote_device_str(rank, device_type, num_devices_per_node): + if device_type.lower() == "cpu": + return f"rank:{rank}/{device_type}" + elif device_type.lower() == "hpu": + return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}" + else: + return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}" + + +def _create_chunk_sharded_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> ShardedTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local shard to create a ShardedTensor. + """ + chunks = tensor.chunk(world_size, dim=0) + if len(chunks) > rank: + local_shard = chunks[rank].clone() + offsets = [0 for _ in tensor.size()] + offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank + local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)] + else: + local_shards = [] + + # Create a ShardedTensor without invoking communication. + chunk_sizes = [list(chunk.size()) for chunk in chunks] + dim0_offsets = [0] + list( + itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes]) + )[:-1] + offsets = [0] * (len(chunk_sizes[0]) - 1) + chunk_offsets = [[d0] + offsets for d0 in dim0_offsets] + device_type = ( + distributed_c10d._get_pg_default_device(pg).type + if device is None + else device.type + ) + placements = [ + _get_remote_device_str( + dist.get_global_rank(pg, r), + device_type, + num_devices_per_node, + ) + for r in range(len(chunk_sizes)) + ] + if len(chunk_sizes) != len(chunk_offsets) or len(chunk_sizes) != len(placements): + raise AssertionError( + f"Expected chunk_sizes, chunk_offsets, and placements to have the same length, " + f"got {len(chunk_sizes)}, {len(chunk_offsets)}, {len(placements)}" + ) + shard_metadata = [ + ShardMetadata(offset, size, placement) + for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements) + ] + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=tensor.size(), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg + ) + + +def _create_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, +) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local tensor to create a DTensor. + """ + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.detach().clone() + + # FSDP placements: [Shard(0)] + # HSDP placements: [Replicate(), Shard(0)] + replicate_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements[-1] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, device_mesh, replicate_placements, run_check=False + ).redistribute( + placements=shard_placements, + ) + + +def _all_gather_dtensor( + tensor: DTensor, + root_mesh: Optional[DeviceMesh], +) -> torch.Tensor: + """ + All gather a DTensor in its sharded dimension and return the local tensor. + """ + if root_mesh != tensor.device_mesh: + raise AssertionError("The device mesh of a tensor should be a root mesh.") + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP placements: [Shard(0)] -> [Replicate()] + # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + placements[-1] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + + return tensor.to_local() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec648ced837e155018c7002560bb7e297b163c78 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py @@ -0,0 +1,932 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +import warnings +from collections.abc import Callable, Generator, Iterator +from typing import Any, cast, no_type_check + +import torch +import torch.distributed as dist +import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard, + ShardedTensor, +) +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _has_fsdp_params, + _is_composable, + _module_handle, + clean_tensor_name, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._runtime_utils import ( + _cast_buffers_to_dtype_and_device, + _get_orig_buffer_dtypes, + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + FullStateDictConfig, + ShardingStrategy, + StateDictType, +) +from torch.distributed.tensor import DTensor +from torch.distributed.utils import _replace_by_prefix + +from ._fsdp_extensions import ( + _ext_all_gather_dtensor, + _ext_chunk_dtensor, + _ext_chunk_tensor, + _ext_post_unflatten_transform, + _ext_pre_load_state_dict_transform, +) +from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM + + +logger = logging.getLogger(__name__) + + +def _should_unshard_params(fsdp_state: _FSDPState) -> bool: + return not ( + fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + and (_is_composable(fsdp_state) or fsdp_state._use_orig_params) + ) + + +def _convert_to_wrapped_module_name(module_name: str) -> str: + module_name = module_name.replace(f"{FSDP_PREFIX}", "") + module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # `CheckpointWrapper` adds a prefix that has to be removed as well. + module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "") + return module_name + + +def _param_name_infos( + module: nn.Module, fsdp_state: _FSDPState +) -> Iterator[tuple[str, str, str]]: + if not _has_fsdp_params(fsdp_state, module): + return + for param_name, module_name in _module_handle( + fsdp_state, module + ).param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +def _shared_param_name_infos( + module: nn.Module, fsdp_state +) -> Iterator[tuple[str, str, str]]: + for param_name, module_name in _module_handle( + fsdp_state, module + ).shared_param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +@no_type_check +def _enter_unshard_params_ctx( + module: nn.Module, + fsdp_state: _FSDPState, + writeback: bool = False, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +) -> None: + """ + state_dict hooks cannot use the pure context call as the checkpoint flow + requires to enter the context in the pre-hook but leave the context in the + post-hook. This API enters the context of ``_unshard_fsdp_state_params``. + """ + if module in fsdp_state._unshard_params_ctx: + raise AssertionError( + "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] " + "is not None." + ) + fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params( + module, + fsdp_state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + fsdp_state._unshard_params_ctx[module].__enter__() + + +@no_type_check +def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: + """A helper function to exit ``_unshard_fsdp_state_params`` context.""" + fsdp_state._unshard_params_ctx[module].__exit__(None, None, None) + fsdp_state._unshard_params_ctx.pop(module) + + +def _common_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, +) -> None: + """Performs the pre-state_dict tasks shared by all state_dict types.""" + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # TODO: need to check if this is always correct for composable FSDP. + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles) + + +def _common_unshard_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + offload_to_cpu: bool, + rank0_only: bool, +) -> None: + """ + Performs the pre-state_dict tasks shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + """ + # For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases. + if not _should_unshard_params(fsdp_state): + return + _enter_unshard_params_ctx( + module, + fsdp_state, + writeback=False, + offload_to_cpu=offload_to_cpu, + rank0_only=rank0_only, + ) + + +@no_type_check +def _common_unshard_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, + param_hook: Callable, +) -> dict[str, Any]: + """ + The post-state_dict flow that shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + hook. + """ + _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) + # Return early for trivial cases + if not state_dict or not _has_fsdp_params(fsdp_state, module): + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # If a rank does not have unsharded parameters(when `rank0_only=True` + # and `rank != 0`), then the rank only needed to participate in the + # all-gather and does not need to save the # state dict. We simply check + # rank0_only to ensure this issue. + rank0_only = ( + fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only + ) + # no_fsdp_return means the state_dict returned by this rank should contain + # only non-FSDP controlled parameters and buffers. + no_fsdp_return = rank0_only and fsdp_state.rank != 0 + if no_fsdp_return and not fsdp_state._use_orig_params: + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + state_dict.pop(f"{prefix}{clean_key}", None) + # Non-zero ranks have flat_param key when rank0_only=True, because rank0_only=True is + # passed in to unshard context, but nonzero ranks reshard early, causing this flat_param + # to appear in state_dict. + state_dict.pop(f"{prefix}{FLAT_PARAM}") + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # Loop only the parameters saved in this instance's wrapped module to + # avoid processing buffers. + for fqn, param_name, module_name in _param_name_infos(module, fsdp_state): + fqn = f"{prefix}{fqn}" + if no_fsdp_return: + state_dict.pop(fqn) + continue + if fqn not in state_dict: + raise AssertionError( + f"FSDP assumes {fqn} is in the state_dict but the state_dict only " + f"has {state_dict.keys()}. " + f"prefix={prefix}, module_name={module_name}, " + f"param_name={param_name} rank={fsdp_state.rank}." + ) + + param_hook(state_dict, prefix, fqn) + + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + + cpu_device = torch.device("cpu") + buffer_clean_fqns = [] + buffers = [] + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_tensor_name(clean_key) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if no_fsdp_return: + state_dict.pop(fqn) + else: + buffer = state_dict[fqn] + if ( + fsdp_state._state_dict_config.offload_to_cpu + and buffer.device != cpu_device + ): + state_dict[fqn] = buffer.to(cpu_device) + # skip upcasting for ignored buffers + if clean_key not in fsdp_state._ignored_buffer_names: + buffer_clean_fqns.append(clean_key) + buffers.append(state_dict[fqn]) + + if buffers: + mixed_precision_enabled_for_buffers = ( + fsdp_state._mixed_precision_enabled_for_buffers() + if not _is_composable(fsdp_state) + else (fsdp_state.mixed_precision.buffer_dtype is not None) + ) + if mixed_precision_enabled_for_buffers: + buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes, fsdp_state.compute_device + ) + for buffer, clean_fqn in zip(buffers, buffer_clean_fqns): + fqn = f"{prefix}{clean_fqn}" + logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype) + state_dict[fqn] = buffer.clone() + return state_dict + + +@no_type_check +def _full_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. pre-state_dict hook is + not actually supported by ``nn.Module``. As a result, this API is called + from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict + is supported in ``nn.Module``, this hook will be registered as a hook in + ``nn.Module``. + """ + if getattr(fsdp_state, "_device_mesh", False): + fsdp_state._device_mesh._get_root_mesh() + + _common_pre_state_dict_hook(module, fsdp_state) + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, + rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, + ) + + +@no_type_check +def _full_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _unshard_fsdp_state_params ends, and also remove + the ``FSDP_WRAPPED_MODULE`` prefix. + """ + + def param_hook( + state_dict: dict[str, Any], + prefix: str, + fqn: str, + ) -> None: + clean_key = fqn + clean_prefix = clean_tensor_name(prefix) + # Strip prefix out of key if needed as buffer names and param names + # do not have prefix considered as they are not computed in `state_dict` + # call. + clean_key = clean_key.removeprefix(clean_prefix) + + # Clone parameters before exiting the `_unshard_fsdp_state_params()` context. + if not getattr(state_dict[fqn], "_has_been_cloned", False): + try: + state_dict[fqn] = state_dict[fqn].detach().clone() + state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] + except BaseException as e: # noqa: B036 + warnings.warn( + f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " + "This may mean that this state_dict entry could point to invalid " + "memory regions after returning from state_dict() call if this " + "parameter is managed by FSDP. Please check clone " + f"implementation of {fqn}. Error: {str(e)}", + stacklevel=2, + ) + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +def _full_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + _lazy_init(fsdp_state, module) + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + # Add FSDP_PREFIX only for wrapper-based FSDP. + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + + +def _full_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +def _local_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Right now, pre-state_dict + hook is not supported by the PyTorch core. So this API is called from + `_local_post_state_dict_hook()` to simulate the case. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``local_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + + +@no_type_check +def _local_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + This hook create a ShardedTensor from the local flat_param and replace + the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy + will happen. The underlying storage is the same. + """ + + _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) + if not _has_fsdp_params(fsdp_state, module): + return state_dict + + # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor + # value as the flat_param but it is a pure Tensor because + # nn.Module.state_dict() will detach the parameter. Therefore, we need + # to get flat_param to get the metadata. + if not _module_handle(fsdp_state, module): + raise AssertionError("Should have returned early") + flat_param = _module_handle(fsdp_state, module).flat_param + # Constructs a ShardedTensor from the flat_param "without" padding. + # Removing the padding allows users to change the number of ranks + # when loading the local_state_dict. + full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] + shard_offset = flat_param.numel() * fsdp_state.rank + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + if valid_data_size > 0: + # If FlatParameter is returned, FlatParameter._local_shard cause a + # pickling issue (can be torch.save but not torch.load). Since there + # is no benefit for state_dict to return the actual FlatParameter class, + # a view (which is a tensor) of the FlatParameter will be returned. + flat_param = flat_param[:valid_data_size].view(valid_data_size) + local_shards = [ + Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank) + ] + else: + local_shards = [] + sharded_tensor = init_from_local_shards( + local_shards, full_numel, process_group=fsdp_state.process_group + ) # type: ignore[assignment] + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor + return state_dict + + +def _local_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + pass + + +def _local_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + This hook finds the local flat_param for this FSDP module from the + state_dict. The flat_param should be a ShardedTensor. This hook converts + the ShardedTensor to a tensor. No copy happen unless padding is required. + """ + _lazy_init(fsdp_state, module) + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") + fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}" + if fqn not in state_dict: + if _has_fsdp_params(fsdp_state, module): + raise AssertionError( + "No `FlatParameter` in `state_dict` for this FSDP instance " + "but it has parameters" + ) + return + load_tensor = state_dict[fqn] + if not isinstance(load_tensor, ShardedTensor): + raise AssertionError("Tensors in local_state_dict should be ShardedTensor.") + + # Convert the ShardedTensor to a Tensor. + flat_param = _module_handle(fsdp_state, module).flat_param + if flat_param is None: + raise AssertionError("Expected flat_param to be set") + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + shards = load_tensor.local_shards() + if valid_data_size > 0: + if not len(shards): + raise AssertionError( + "load_local_state_dict assume one shard per ShardedTensor." + ) + load_tensor = shards[0].tensor + + # Get the metadata of the flat_param to decide whether to pad the loaded + # tensor. + if flat_param._shard_numel_padded > 0: + if load_tensor.numel() >= flat_param.numel(): + raise AssertionError( + f"Local shard size = {flat_param.numel()} and the tensor in " + f"the state_dict is {load_tensor.numel()}." + ) + load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded]) + else: + load_tensor = flat_param + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + state_dict[fqn] = load_tensor + + +def _sharded_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Check + ``_full_pre_load_state_dict_hook`` for the detail. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``sharded_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + # Setting offload_to_cpu here does not work even if offload_to_cpu is True. + # We have to create ShardedTensor first then move it to CPU. + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=False, + rank0_only=False, + ) + + +@no_type_check +def _sharded_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + The hook replaces the unflattened, unsharded parameter in the state_dict + with a unflattened, sharded parameter (a ShardedTensor). + """ + + def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str): + param = state_dict[fqn] + if not fsdp_state._state_dict_config._use_dtensor: + sharded_tensor = _ext_chunk_tensor( + tensor=param, + rank=fsdp_state.rank, + world_size=fsdp_state.world_size, + num_devices_per_node=fsdp_state._device_handle.device_count(), + pg=fsdp_state.process_group, + fsdp_extension=fsdp_state._fsdp_extension, + ) + else: + sharded_tensor = _ext_chunk_dtensor( + tensor=param, + rank=fsdp_state.rank, + device_mesh=fsdp_state._device_mesh, + fsdp_extension=fsdp_state._fsdp_extension, + ) + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[fqn] = sharded_tensor + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +@no_type_check +def _sharded_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _has_fsdp_params(fsdp_state, module): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +@no_type_check +def _sharded_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + The hook combines the unflattened, sharded parameters (ShardedTensor) to + a new FlatParameter and shards the new FlatParameter to the local chunk. + """ + _lazy_init(fsdp_state, module) + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + if not _has_fsdp_params(fsdp_state, module): + return + + handle = _module_handle(fsdp_state, module) + if not handle.uses_sharded_strategy: + raise RuntimeError( + "load_sharded_state_dict can only be called when parameters " + "are flattened and sharded." + ) + fqn_to_param_ext = dict( + zip(handle.flat_param._fqns, handle.flat_param._param_extensions) + ) + + for fqn, _, _ in _param_name_infos(module, fsdp_state): + if not _is_composable(fsdp_state): + fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}" + else: + fqn_from_global_root = f"{prefix}{fqn}" + try: + param = state_dict.pop(fqn_from_global_root) + except KeyError: + logger.warning( + f"Did not find param with FQN {fqn_from_global_root}, skipping it. " # noqa: G004 + "The weight will not be filled if you expect it to be." + ) + continue # TODO: Improve unittesting for state_dict finetuning + # cases: https://github.com/pytorch/pytorch/issues/109134 + + if not fsdp_state._state_dict_config._use_dtensor: + # All-gather the param (ShardedTensor) + param, shards = _ext_pre_load_state_dict_transform( + param, fsdp_state._fsdp_extension + ) + + if len(shards) >= 2: + raise AssertionError( + "Expects 0 or 1 shard per rank " + f"but got {len(shards)} shards on rank {fsdp_state.rank}." + ) + param_numel = param.size().numel() + dim_0_size = param.size()[0] + chunk_size = ( + math.ceil(dim_0_size / fsdp_state.world_size) + * param_numel + // dim_0_size + ) + if len(shards) == 1: + local_tensor = shards[0].tensor.flatten() + with SimpleProfiler.profile(SimpleProfiler.Type.H2D): + local_tensor = local_tensor.to(fsdp_state.compute_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=param.dtype, device=fsdp_state.compute_device + ) + tensor = torch.empty( + chunk_size * fsdp_state.world_size, + dtype=local_tensor.dtype, + device=fsdp_state.compute_device, + ) + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group + ) + tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) + state_dict[fqn_from_global_root] = tensor + else: + if param.device != fsdp_state._device_mesh.device_type: + param = param.to(fsdp_state._device_mesh.device_type) + + root_mesh = fsdp_state._device_mesh._get_root_mesh() + local_tensor = _ext_all_gather_dtensor( + param, root_mesh, fsdp_state._fsdp_extension + ) + + if fqn_to_param_ext.get(fqn) is not None: + ext = fqn_to_param_ext[fqn] + local_tensor = _ext_post_unflatten_transform( + local_tensor, ext, fsdp_state._fsdp_extension + ) + state_dict[fqn_from_global_root] = local_tensor + + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + + +@contextlib.contextmanager +def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator: + old_state_dict_config = fsdp_state._state_dict_config + old_state_dict_type = fsdp_state._state_dict_type + fsdp_state._state_dict_config = FullStateDictConfig() + fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT + yield + fsdp_state._state_dict_config = old_state_dict_config + fsdp_state._state_dict_type = old_state_dict_type + + +@no_type_check +@torch.no_grad() +def _post_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this + FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide + what postprocessing will be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned.", + stacklevel=2, + ) + else: + context = contextlib.nullcontext() + + with context: + _post_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, + } + processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + if fsdp_state._is_root: + logger.info("FSDP finished processing state_dict(), prefix=%s", prefix) + for key, tensor in sorted(processed_state_dict.items()): + if key.startswith(prefix) and isinstance(tensor, torch.Tensor): + local_shape = tensor.shape + device = None + if isinstance(tensor, ShardedTensor): + local_shape = None + shards = tensor.local_shards() + if shards: + local_shape = shards[0].tensor.shape + device = shards[0].tensor.device + elif isinstance(tensor, DTensor): + local_shape = tensor.to_local().shape + device = tensor.device + else: + device = tensor.device + logger.info( + "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s", + key, + type(tensor), + tensor.shape, + local_shape, + tensor.dtype, + device, + ) + + return processed_state_dict + + +@no_type_check +@torch.no_grad() +def _pre_state_dict_hook( + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + This is called before the core state dict saving logic of ``module``. + ``fsdp_state._state_dict_type`` is used to decide what postprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned.", + stacklevel=2, + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + with context: + _pre_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook, + } + _pre_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_state, + module, + *args, + **kwargs, + ) + + +@no_type_check +def _set_use_dtensor(fsdp_state: _FSDPState) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedStateDictConfig(). + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = fsdp_state._state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + fsdp_state._state_dict_config._use_dtensor = True + + +@no_type_check +@torch.no_grad() +def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> None: + """ + This is called before ``module._load_from_state_dict()``. + ``fsdp_state._state_dict_type`` is used to decide what preprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned.", + stacklevel=2, + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + SimpleProfiler.reset() + + with context: + _pre_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, + } + # Code that is common for all state_dict impls + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # Dispatch into state_dict specific implementation of pre-hook. + _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + +@no_type_check +@torch.no_grad() +def _post_load_state_dict_hook( + module: nn.Module, + incompatible_keys: tuple[list[str], list[str]], + *args: Any, +) -> None: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned.", + stacklevel=2, + ) + else: + context = contextlib.nullcontext() + + with context: + _post_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, + } + # Code that is common for all state_dict impls + # Dispatch into state_dict type specific implementation of post-hook for + # loading state_dict. + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state) + + # When reporting incompatible keys, trim FSDP prefixes. + missing_keys = incompatible_keys[0] + unexpected_keys = incompatible_keys[1] + for i in range(len(missing_keys)): + missing_keys[i] = clean_tensor_name(missing_keys[i]) + + for i in range(len(unexpected_keys)): + unexpected_keys[i] = clean_tensor_name(unexpected_keys[i]) + + if fsdp_state._is_root: + SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ") + + +def _register_all_state_dict_hooks(state: _FSDPState): + """ + Registers pre-save, post-save, pre-load, and post-load state dict hooks. + """ + for hook_registration_fn_str, hook, hook_registration_fn_kwargs in ( + ("register_state_dict_pre_hook", _pre_state_dict_hook, {}), + ("_register_state_dict_hook", _post_state_dict_hook, {}), + ( + "_register_load_state_dict_pre_hook", + _pre_load_state_dict_hook, + {"with_module": True}, + ), + ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}), + ): + _register_state_dict_hooks_base( + state, hook_registration_fn_str, hook, hook_registration_fn_kwargs + ) + + +@no_type_check +def _register_state_dict_hooks_base( + state: _FSDPState, + hook_registration_fn_name: str, + hook: Callable, + hook_registration_fn_kwargs: dict[str, Any], +) -> None: + """Registers ``hook`` using ``hook_registration_fn``.""" + if not _is_composable(state): + getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs) + else: + handle = state._handle + if handle: + getattr(handle._fully_sharded_module, hook_registration_fn_name)( + hook, **hook_registration_fn_kwargs + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d514c5c6474b3a984424b1cd7563e1656f3f2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_trace_utils.py @@ -0,0 +1,240 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, NamedTuple, Optional + +import torch +import torch.nn as nn + + +@dataclass +class TracingConfig: + """ + This represents a symbolic tracing configuration. + + Args: + tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to + use for symbolic tracing. The default value is the native + :class:`torch.fx.Tracer` constructed with default arguments. + However, the user may want to pass a different value such as the + ``HFTracer`` for models in the HuggingFace Transformers_ library. + .. _Transformers: https://huggingface.co/docs/transformers/index + concrete_args (Optional[Dict[str, Any]]): Concrete arguments that + should not be treated as ``torch.fx.Proxy`` when tracing the + module ``forward()``. Passing ``concrete_args`` allows partially + specializing the forward, e.g. to remove control flow or data + structures. This ``concrete_args`` here is the same argument used + in :meth:`~torch.fx.Tracer.trace`. + """ + + tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer) + concrete_args: Optional[dict[str, Any]] = None + + +class _ParamUsageInfo(NamedTuple): + """ + This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record + execution information. The ``dict`` maps modules to a list of these + ``_ParamUsageInfo`` instances, where each instance represents a group of + parameters used together. + + Specifically, for each module key in the ``dict``, each instance of this + class represents either: + (1) the module and some sublist of its ``named_parameters()`` used + together in execution (see ``_patched_create_proxy()``), or + (2) a submodule and all of ``submodule.named_parameters()`` (see + ``_patched_call_module()``). + + Type (1) corresponds to directly using parameters in ops without calling + ``forward()``, and type (2) corresponds to calling ``forward()``. The + mapped-to lists in the ``dict`` follow the execution order. + """ + + module: nn.Module + named_params: list[tuple[str, nn.Parameter]] + + +class _ExecutionInfo: + """ + This represents the execution order information from the forward pass. + + Attributes: + curr_module (nn.Module): Current module being traced. + module_forward_order (List[nn.Module]): The modules in (pre-)forward + order, i.e. the order in which their ``forward()`` methods are + called. Each call to a module's ``forward()`` corresponds to one + element in the list. + module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]): + Maps a module to a list of module execution infos. See + :class:`_ParamUsageInfo` for details. + param_forward_order (List[nn.Parameter]): The parameters in forward + execution order, where only a parameter's first participation is + included. + visited_params (Set[nn.Parameter]): The parameters visited so far + during the trace. This is only used during tracing for fast + membership check. Invariant: The parameters in + ``param_forward_order`` are exactly those in ``visited_params``. + """ + + def __init__(self, root_module: nn.Module) -> None: + self.curr_module: nn.Module = root_module + self.module_forward_order: list[nn.Module] = [root_module] + self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = { + root_module: [] + } + self.param_forward_order: list[nn.Parameter] = [] + self.visited_params: set[nn.Parameter] = set() + + +class _ExecOrderTracer: + def __init__(self) -> None: + self.exec_info: Optional[_ExecutionInfo] = None + + @contextmanager + def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module): + self.exec_info = _ExecutionInfo(root_module) + orig_call_module = tracer.call_module + orig_create_proxy = tracer.create_proxy + tracer.call_module = functools.partial( # type: ignore[method-assign] + self._patched_call_module, orig_call_module, self.exec_info + ) + fqn_to_param = dict(root_module.named_parameters()) + tracer.create_proxy = functools.partial( # type: ignore[method-assign] + self._patched_create_proxy, + orig_create_proxy, + self.exec_info, + fqn_to_param, + ) + try: + yield + finally: + tracer.call_module = orig_call_module # type: ignore[method-assign] + tracer.create_proxy = orig_create_proxy # type: ignore[method-assign] + + def _patched_call_module( + self, + call_module: Callable, + exec_info: _ExecutionInfo, + # Below are the expected arguments to `call_module()` + module: nn.Module, + forward: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + """ + Overrides ``call_module`` to save execution information to + ``exec_info``. Note that ``call_module`` is called during symbolic + tracing for each non-root module. + + Args: + call_module (Callable): Original ``call_module`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + module (nn.Module): Module corresponding to this ``call_module``. + forward (Callable): ``forward()`` method of ``module`` to be called + for this ``call_module``. + args (Tuple[Any, ...]): Positional arguments for ``forward``. + kwargs (Dict[str, Any]): Keyword arguments for ``forward``. + + Returns: + Same return value as ``call_module``. + """ + exec_info.module_forward_order.append(module) + named_params = list(module.named_parameters()) + curr_module = exec_info.curr_module + if named_params: + if curr_module not in exec_info.module_to_param_usage_infos: + raise AssertionError( + "The current module should have already been processed by a patched `call_module`" + ) + exec_info.module_to_param_usage_infos[exec_info.curr_module].append( + _ParamUsageInfo(module, named_params) + ) + prev_curr_module = curr_module + exec_info.curr_module = module + exec_info.module_to_param_usage_infos[module] = [] + output = call_module(module, forward, args, kwargs) + exec_info.curr_module = prev_curr_module + return output + + def _patched_create_proxy( + self, + create_proxy: Callable, + exec_info: _ExecutionInfo, + fqn_to_param: dict[str, nn.Parameter], + # Below are the expected arguments to `create_proxy()` + kind: str, + target: torch.fx.node.Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None, + ) -> torch.fx.Proxy: + """ + Overrides ``create_proxy`` to save execution information to + ``exec_info``. Note that ``create_proxy`` is called during symbolic + tracing for each leaf function/method/module. + + Args: + create_proxy (Callable): Original ``create_proxy`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the + root module's ``named_parameters()`` with FQN as key and + parameter as value. + kind (str): Kind of the target method ('call_function', + 'call_method', 'get_attr', 'call_module', 'placeholder', or + 'output'). See :class:`torch.fx.Graph` for details. This is + passed to ``create_proxy``. + target (torch.fx.node.Target): Contains the string name of the + function/method/module. This is passed to ``create_proxy``. + args (Tuple[Any, ...]): Positional arguments for the function/ + method/module. This is passed to ``create_proxy``. + kwargs (Dict[str, Any]): Keyword arguments for the function/method/ + module. This is passed to ``create_proxy`` + name (Optional[str]): An optional string name for the ``Node`` + created in ``create_proxy``. This is passed to + ``create_proxy``. + type_expr (Optional[Any]): An optional type annotation representing + the Python type that the output of the node has. This is passed + to ``create_proxy``. + proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): + An alternative proxy constructor used in ``create_proxy``. This + is passed to ``create_proxy``. + + Returns: + torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object. + """ + proxy = create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + curr_module = exec_info.curr_module + if kind in ("call_function", "call_method"): + if args is not None: + named_params: list[tuple[str, nn.Parameter]] = [] + for arg in args: + if ( + isinstance(arg, torch.fx.Proxy) + and arg.node.target in fqn_to_param + ): + param = fqn_to_param[arg.node.target] # type: ignore[index] + named_params.append((arg.node.target, param)) # type: ignore[arg-type] + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + elif kind == "call_module": + named_params = list(curr_module.named_parameters()) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + for _, param in named_params: + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + return proxy diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51140d3b0a8d3d16ab50226b414e651f22772648 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_traversal_utils.py @@ -0,0 +1,112 @@ +""" +NOTE: This file must be imported like +``import torch.distributed.fsdp._traversal_utils`` and not like +``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular +imports. For brevity, we may import the file as ``traversal_utils``. +""" + +import collections + +import torch.nn as nn +from torch.distributed._composable.contract import _get_registry +from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state + + +""" +[Note: FSDP State Traversal] +For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel`` +module wrapping a fully sharded module, and for the non-wrapper code path, +``_FSDPState`` is an object that gets embedded on a fully sharded module. +See [Note: Fully Sharded Module] for the definition. + +There are three common traversal idioms: Given a root module, +- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree. +- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the +tree (i.e. those with ``_is_root == True``). +- ``_get_fsdp_handles()``returns all ``FlatParamHandle`` s in the tree. + +All of these methods must take in the root module (i.e. an ``nn.Module``) and +not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph +traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal. +""" + + +def _composable(module: nn.Module) -> bool: + """ + Returns if ``module`` can compose with ``fully_shard``. + """ + # TODO: Add any other composable APIs that are mutually exclusive. + registry = _get_registry(module) + if registry is None: + return True + return "replicate" not in registry + + +# TODO (awgu): We may be able to remove this function if we retired the +# `use_orig_params=False` code path since so far we only need the module for +# `FlatParameter` registration, which is not needed for `use_orig_params=True`. +def _get_fsdp_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the modules owning the states in the first list. + + For the wrapper code path, both returned lists are the same, each + containing all ``FullyShardedDataParallel`` instances. For the composable + code path, this returns a list of all composable state instances and a list + of the corresponding fully sharded modules. See [Note: Fully Sharded + Module]. + + NOTE: The traversal does not proceed into any module annotated by an + incompatible API (e.g. ``replicate``). + """ + fsdp_states: list[_FSDPState] = [] + fsdp_modules: list[nn.Module] = [] + # Track the visited FSDP states since multiple modules may share the same + # one and we want to return a de-duplicated list + visited_fsdp_states: set[_FSDPState] = set() + # Track the visited modules in case of shared modules, which implies the + # module graph is no longer a tree + visited_modules: set[nn.Module] = set() + + # Perform depth-first search from `module` to ensure that we do not + # traverse into an incompatible API's subtree (use DFS instead of BFS to + # match `.modules()` order) + deque: collections.deque[nn.Module] = collections.deque([module]) + while deque: + submodule = deque.popleft() + visited_modules.add(submodule) + if not _composable(submodule): + continue + for child_module in reversed(list(submodule.children())): + if child_module not in visited_modules: + deque.appendleft(child_module) + optional_state = _get_module_fsdp_state(submodule) + if optional_state is not None and optional_state not in visited_fsdp_states: + visited_fsdp_states.add(optional_state) + fsdp_states.append(optional_state) + fsdp_modules.append(submodule) + return fsdp_states, fsdp_modules + + +def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_states_with_modules`.""" + fsdp_states, _ = _get_fsdp_states_with_modules(module) + return fsdp_states + + +def _get_fsdp_handles(module: nn.Module) -> list: + """ + Returns all ``FlatParamHandle`` s in the module tree rooted at ``module`` + following the rules in :func:`_get_fsdp_state`. + """ + handles = [ + fsdp_state._handle + for fsdp_state in _get_fsdp_states(module) + if fsdp_state._handle is not None + ] + return handles diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71dc1a9f4e28c7101fc0acdae2582be89e954013 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py @@ -0,0 +1,340 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from collections.abc import Generator +from typing import cast + +import torch +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state, + _has_fsdp_params, + _module_handle, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, + _reshard, + _reshard_grads, + _unshard, + _unshard_grads, +) +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParamHandle + + +FLAT_PARAM = "_flat_param" + + +@torch.no_grad() +def _writeback_to_local_shard( + handle: FlatParamHandle, + writeback_grad: bool, +): + """ + For the handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. If + ``writeback_grad=True``, then writes back to the sharded gradient as + well. + + Precondition: The handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + + def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor: + if handle.uses_sharded_strategy: + # For sharded strategies, get the *unpadded* shard instead of + # the *padded* shard to persist user changes to the padding + # (though FSDP does not explicitly support this) + shard, _ = FlatParamHandle._get_unpadded_shard( + flat_param_or_grad, + handle.rank, + handle.world_size, + ) + return shard + # For `NO_SHARD`, the `flat_param` or its gradient may be modified, + # so we write it back directly + return flat_param_or_grad + + param_shard = _get_shard(handle.flat_param) + handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] + if writeback_grad: + existing_grad = handle.sharded_grad + if existing_grad is not None: + if handle.flat_param.grad is None: + raise AssertionError("Expected handle.flat_param.grad to not be None") + grad_shard = _get_shard(handle.flat_param.grad) + existing_grad[: grad_shard.numel()].copy_(grad_shard) + + +def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + De-registers the flattened parameter from the wrapped module, hiding it + from ``nn.Module`` methods. + + We do not use ``del`` because we want ``FLAT_PARAM`` to always be an + attribute but dynamically change whether it is visible to ``nn.Module`` + methods. + """ + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) + + +def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + Registers the flattened parameter to the wrapped module, making it + visible to ``nn.Module`` methods. + + We do not use :meth:`nn.Module.register_parameter` because we want + ``FLAT_PARAM`` to always be an attribute but dynamically change whether + it is visible to ``nn.Module`` methods. + """ + handle = _module_handle(state, module) + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param + + +@contextlib.contextmanager +def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + de-registers the flattened parameter and unflattens the original + parameters as ``nn.Parameter`` views into the flattened parameter. + After the context, re-registers the flattened parameter and restores + the original parameters as ``Tensor`` views into the flattened + parameter. + """ + handle = _module_handle(state, module) + if not handle: + yield + else: + _deregister_flat_param(state, module) + try: + with handle.unflatten_as_params(): + yield + finally: + if not handle._use_orig_params: + _register_flat_param(state, module) + + +def _validate_unshard_params_args( + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +) -> None: + if with_grads and (offload_to_cpu or not state._use_orig_params): + raise NotImplementedError( + f"with_grads={with_grads}, " + f"use_orig_params={state._use_orig_params}, " + f"offload_to_cpu={offload_to_cpu} " + f"is not supported yet" + ) + if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy): + raise NotImplementedError( + "offload_to_cpu=True and NO_SHARD is not supported yet" + ) + if writeback and rank0_only: + # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to + # persist the changes. + raise NotImplementedError( + "writeback=True and rank0_only=True is not supported yet" + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu=True and rank0_only=False may result in the" + "unsharded parameters being redundantly copied to CPU memory for " + "GPUs sharing the same CPU memory, which risks CPU OOM. We " + "recommend using offload_to_cpu=True with rank0_only=True.", + stacklevel=2, + ) + + +@contextlib.contextmanager +def _unshard_fsdp_state_params( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards the parameters for a single FSDP state ``state`` that + corresponds to ``module``. + """ + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + state._device_handle.synchronize() + # If handles are shared by other module(s), the handle may be already unsharded. + maybe_handle = _module_handle(state, module) + handle = None + if ( + maybe_handle + and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS + ): + handle = maybe_handle + if not handle: + yield + return + + if handle._training_state != HandleTrainingState.IDLE: + raise AssertionError( + f"Expects the handle training to be IDLE but got {handle._training_state}" + ) + + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + _reset_flat_param_grad_info_if_needed(handle) + free_unsharded_flat_param = handle.needs_unshard() + # No need to call `wait_stream()` since we unshard in the computation + # stream directly + computation_stream = state._device_handle.current_stream() + _unshard(state, handle, computation_stream, computation_stream) + if with_grads: + _unshard_grads(handle) + + if rank0_only and state.rank != 0: + # Free the unsharded flattened parameter early + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + try: + yield + finally: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # NOTE: Since PyTorch enforces that a parameter and its + # gradients need to match metadata (e.g. device), we must + # move gradients to CPU *after* we move parameters. + # NOTE: This assumes 1 `FlatParameter` + if not state._use_orig_params: + stack.enter_context(_unflatten_as_params(state, module)) + try: + yield + finally: + stack.close() + if writeback: + _writeback_to_local_shard(handle, with_grads) + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + handle._training_state = HandleTrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params_for_summon( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + _lazy_init(state, module) + if state.training_state == TrainingState.FORWARD_BACKWARD: + raise AssertionError( + "Cannot manually unshard parameters during forward/backward" + ) + elif state.training_state == TrainingState.SUMMON_FULL_PARAMS: + raise AssertionError( + "Cannot manually unshard parameters when already unsharding parameters" + ) + with _unshard_fsdp_state_params( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ): + try: + state.training_state = TrainingState.SUMMON_FULL_PARAMS + yield + finally: + state.training_state = TrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params( + module: nn.Module, + recurse: bool, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards FSDP-managed parameters for all modules with FSDP applied in + the module tree rooted at ``module``. + """ + if not recurse: + optional_state = _get_module_fsdp_state(module) + if optional_state is None: + with contextlib.nullcontext(): + yield + return + states_and_modules = ([optional_state], [module]) + else: + states_and_modules = traversal_utils._get_fsdp_states_with_modules(module) + with contextlib.ExitStack() as stack: + for state, module in zip(*states_and_modules): + stack.enter_context( + _unshard_params_for_summon( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + ) + yield + + +def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the original parameters; registers the ``FlatParameter``. + """ + handle = _module_handle(state, module) + if not handle: + return + _p_assert( + handle._use_orig_params, + f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " + f"handle: {handle._use_orig_params}", + ) + handle._deregister_orig_params() + _register_flat_param(state, module) + + +def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the ``FlatParameter``; registers the original parameters. + """ + handle = _module_handle(state, module) + if not handle: + return + _deregister_flat_param(state, module) + if handle.is_sharded(handle.flat_param): + handle._use_sharded_views() + handle._use_sharded_grad_views() + else: + handle._use_unsharded_views(as_params=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41dc4d8575198875b8403c2a41c7b2f547a1b742 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py @@ -0,0 +1,264 @@ +# mypy: allow-untyped-defs +import collections +import functools +import inspect +import warnings +from collections.abc import Callable +from functools import partial +from typing import Any, Union + +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state, + _override_module_mixed_precision, +) +from torch.distributed.fsdp.wrap import ( + _construct_wrap_fn, + _or_policy, + _Policy, + _post_order_apply, + _recursive_wrap, + _run_mixed_precision_override_policy, + _wrap_module_cls_individually, +) + + +def _auto_wrap( + root_module: nn.Module, + policy: Union[Callable, _Policy], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + root_kwargs: dict[str, Any], + fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard` +): + """ + Auto wraps modules in ``root_module`` 's tree according to ``policy`` + following a post-order traversal. + + Precondition: ``root_kwargs`` should contain all arguments except + ``module``. This function accepts the kwargs dict directly since it gets + forwarded into the post-order traversal function. + """ + mixed_precision = root_kwargs["mixed_precision"] + is_wrapper = inspect.isclass(fsdp_fn) + # TODO: We may relax this no-nested-wrapping constraint to support manual + # wrapping followed by auto wrapping. + _check_nested_wrapping(root_module) + + if isinstance(policy, _Policy): + root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None + target_module_to_kwargs = policy._run_policy( + root_module, ignored_modules, root_kwargs + ) + if mixed_precision is not None: + target_module_to_kwargs = _run_mixed_precision_override_policy( + root_module, + mixed_precision._module_classes_to_ignore, + ignored_modules, + root_kwargs, + target_module_to_kwargs, + ) + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + _warn_on_overridden_mixed_precision(overridden_module_classes) + use_orig_params = root_kwargs.get("use_orig_params", False) + _validate_frozen_params( + root_module, + set(target_module_to_kwargs.keys()), + ignored_params, + use_orig_params, + ) + wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn) + _post_order_apply(root_module, wrap_fn) + return + + recursive_wrap_kwargs = { + "module": root_module, + "auto_wrap_policy": policy, + "wrapper_cls": fsdp_fn, + "ignored_modules": ignored_modules, + "ignored_params": ignored_params, + "only_wrap_children": True, + } + if mixed_precision is not None: + # Wrap modules of the ignored types separately and register forward + # hooks to cast to fp32 and back to the original dtype, respectively + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + policy = functools.partial( + _or_policy, + policies=[ + policy, + partial( + _wrap_module_cls_individually, + module_classes=mixed_precision._module_classes_to_ignore, + ), + ], + ) + recursive_wrap_kwargs["auto_wrap_policy"] = policy + _warn_on_overridden_mixed_precision(overridden_module_classes) + _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] + + +def _check_nested_wrapping(root_module: nn.Module): + for module_name, module in root_module.named_modules(): + if _get_module_fsdp_state(module) is not None: + raise ValueError( + "FSDP auto wrapping requires modules to not already have " + f"FSDP applied but found {module_name} in\n{root_module}" + ) + + +def _warn_on_overridden_mixed_precision( + overridden_module_classes: set[type[nn.Module]], +): + if len(overridden_module_classes) == 0: + return + warnings.warn( + "Both mixed precision and an auto_wrap_policy were specified to FSDP, " + f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n" + "These modules will be wrapped as separate FSDP instacnes with mixed " + "precision disabled.", + stacklevel=2, + ) + + +def _validate_frozen_params( + root_module: nn.Module, + modules_to_wrap: set[nn.Module], + ignored_params: set[nn.Parameter], + use_orig_params: bool, +): + """ + This checks that, given ``modules_to_wrap``, each module would manage + parameters that are uniformly frozen or non-frozen. This uniformity + requirement is strict for ``use_orig_params=False`` (hard error) and highly + recommended for ``use_orig_params=True`` (user warning). + """ + post_order_named_modules = _get_post_order_named_modules(root_module) + visited_modules: set[nn.Module] = set() + for module_name, module in post_order_named_modules: + if module in modules_to_wrap: + param_to_fqn = _get_managed_param_to_fqn( + module, ignored_params, visited_modules, module_name + ) + frozen_param_fqns: list[str] = [] + frozen_param_numel = 0 + nonfrozen_param_fqns: list[str] = [] + nonfrozen_param_numel = 0 + for param, fqn in param_to_fqn.items(): + if param.requires_grad: + nonfrozen_param_fqns.append(fqn) + nonfrozen_param_numel += param.numel() + else: + frozen_param_fqns.append(fqn) + frozen_param_numel += param.numel() + if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0: + msg = f"{module_name} has both parameters with requires_grad=True and False." + if use_orig_params: + total_param_numel = frozen_param_numel + nonfrozen_param_numel + msg += ( + " We do not recommend wrapping such modules since " + "the gradient memory usage will be higher than expected " + f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel " + "before sharding via reduce-scatter). " + ) + else: + msg += " FSDP does not support wrapping such modules when use_orig_params=False. " + msg += "If possible, wrap the frozen parameters with FSDP separately.\n" + msg += ( + f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n" + f"The following parameters have requires_grad=False:\n{frozen_param_fqns}" + ) + if use_orig_params: + warnings.warn(msg, stacklevel=2) + else: + raise ValueError(msg) + + +def _get_post_order_named_modules( + root_module: nn.Module, +) -> list[tuple[str, nn.Module]]: + """ + This returns the named modules following a post-order traversal, which is a + valid reverse topological sort. We achieve this using the reverse of a + stack-based DFS order instead of reversing ``root_module.named_modules()`` + since the former gives the modules in registration order at each level in + the module tree (as opposed to the reverse), which allows us to error/warn + on the first registered module that violates the condition. + + For example, consider the following module structure: + M( + S1(), + S2( + SS1(), + SS2(), + ), + S3(), + ) + The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse + ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M]. + """ + visited_modules = {root_module} + stack = [("", root_module)] + # Append and reverse at the end for linear-time algorithm + reverse_post_order_named_modules: list[tuple[str, nn.Module]] = [] + while stack: + module_name, module = stack.pop() + reverse_post_order_named_modules.append((module_name, module)) + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + if module_name != "": + child_module_name = module_name + "." + child_module_name + stack.append((child_module_name, child_module)) + post_order_named_modules = list(reversed(reverse_post_order_named_modules)) + return post_order_named_modules + + +def _get_managed_param_to_fqn( + module_to_wrap: nn.Module, + ignored_params: set[nn.Parameter], + visited_modules: set[nn.Module], + root_prefix: str, +) -> dict[nn.Parameter, str]: + """ + This returns a dict that maps managed parameter to its FQN for the given + ``module_to_wrap``. The dict's keys are exactly the parameters that would + be managed by the module, where this is achieved by calling this function + on the modules to wrap in reverse topological order, destructively updating + ``visited_modules``, and not traversing into those modules. The FQNs are + prefixed from the root (via ``root_prefix``) to be more informative. + + NOTE: This function is meant to be called pre-wrapping and iteratively in + reverse topological order to cover the full module tree. This differs from + the ``_get_param_to_fqn()`` function meant to be called post-wrapping and + on the full module tree in one shot. Given those differences, we do not try + to unify the two. + """ + param_to_fqn: dict[nn.Parameter, str] = {} + # Run BFS (or any tree traversal works) + queue = collections.deque([(module_to_wrap, root_prefix)]) + visited_modules.add(module_to_wrap) + while queue: + module, prefix = queue.popleft() + for param_name, param in module.named_parameters(recurse=False): + if param not in ignored_params: + fqn = param_name if prefix == "" else prefix + "." + param_name + param_to_fqn[param] = fqn + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + child_prefix = ( + child_module_name + if prefix == "" + else prefix + "." + child_module_name + ) + queue.append((child_module, child_prefix)) + return param_to_fqn diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/api.py new file mode 100644 index 0000000000000000000000000000000000000000..17ed0483f1c26248673fe888bc5489e099b1313b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/api.py @@ -0,0 +1,417 @@ +""" +This file includes public APIs for FSDP such as the classes used for the +constructor arguments. +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Optional + +import torch +from torch.nn.modules.batchnorm import _BatchNorm + + +__all__ = [ + "ShardingStrategy", + "BackwardPrefetch", + "MixedPrecision", + "CPUOffload", + "StateDictType", + "StateDictConfig", + "FullStateDictConfig", + "LocalStateDictConfig", + "ShardedStateDictConfig", + "OptimStateDictConfig", + "FullOptimStateDictConfig", + "LocalOptimStateDictConfig", + "ShardedOptimStateDictConfig", + "StateDictSettings", +] + + +class ShardingStrategy(Enum): + """ + This specifies the sharding strategy to be used for distributed training by + :class:`FullyShardedDataParallel`. + + - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded. + For the parameters, this strategy unshards (via all-gather) before the + forward, reshards after the forward, unshards before the backward + computation, and reshards after the backward computation. For gradients, + it synchronizes and shards them (via reduce-scatter) after the backward + computation. The sharded optimizer states are updated locally per rank. + - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during + computation, and additionally, parameters are sharded outside + computation. For the parameters, this strategy unshards before the + forward, does not reshard them after the forward, and only reshards them + after the backward computation. The sharded optimizer states are updated + locally per rank. Inside ``no_sync()``, the parameters are not resharded + after the backward computation. + - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded + but instead replicated across ranks similar to PyTorch's + :class:`DistributedDataParallel` API. For gradients, this strategy + synchronizes them (via all-reduce) after the backward computation. The + unsharded optimizer states are updated locally per rank. + - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across + nodes. This results in reduced communication volume as expensive all-gathers and + reduce-scatters are only done within a node, which can be more performant for medium + -sized models. + - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across + nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput + since the unsharded parameters are not freed after the forward pass, saving the + all-gathers in the pre-backward. + """ + + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +class BackwardPrefetch(Enum): + """ + This configures explicit backward prefetching, which improves throughput by + enabling communication and computation overlap in the backward pass at the + cost of slightly increased memory usage. + + - ``BACKWARD_PRE``: This enables the most overlap but increases memory + usage the most. This prefetches the next set of parameters *before* the + current set of parameters' gradient computation. This overlaps the *next + all-gather* and the *current gradient computation*, and at the peak, it + holds the current set of parameters, next set of parameters, and current + set of gradients in memory. + - ``BACKWARD_POST``: This enables less overlap but requires less memory + usage. This prefetches the next set of parameters *after* the current + set of parameters' gradient computation. This overlaps the *current + reduce-scatter* and the *next gradient computation*, and it frees the + current set of parameters before allocating memory for the next set of + parameters, only holding the next set of parameters and current set of + gradients in memory at the peak. + - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables + the backward prefetching altogether. This has no overlap and does not + increase memory usage. In general, we do not recommend this setting since + it may degrade throughput significantly. + + For more technical context: For a single process group using NCCL backend, + any collectives, even if issued from different streams, contend for the + same per-device NCCL stream, which implies that the relative order in which + the collectives are issued matters for overlapping. The two backward + prefetching values correspond to different issue orders. + """ + + # NOTE: For both modes, the ordering that defines "current" and "next" is + # not always exact in the current implementation. A mistargeted prefetch + # simply means that the parameter memory is allocated earlier than needed, + # possibly increasing peak memory usage, but does not affect correctness. + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + + +@dataclass +class MixedPrecision: + """ + This configures FSDP-native mixed precision training. + + Attributes: + param_dtype (Optional[torch.dtype]): This specifies the dtype for model + parameters during forward and backward and thus the dtype for + forward and backward computation. Outside forward and backward, the + *sharded* parameters are kept in full precision (e.g. for the + optimizer step), and for model checkpointing, the parameters are + always saved in full precision. (Default: ``None``) + reduce_dtype (Optional[torch.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then this takes on + the ``param_dtype`` value, still running gradient reduction in low + precision. This is permitted to differ from ``param_dtype``, e.g. + to force gradient reduction to run in full precision. (Default: + ``None``) + buffer_dtype (Optional[torch.dtype]): This specifies the dtype for + buffers. FSDP does not shard buffers. Rather, FSDP casts them to + ``buffer_dtype`` in the first forward pass and keeps them in that + dtype thereafter. For model checkpointing, the buffers are saved + in full precision except for ``LOCAL_STATE_DICT``. (Default: + ``None``) + keep_low_precision_grads (bool): If ``False``, then FSDP upcasts + gradients to full precision after the backward pass in preparation + for the optimizer step. If ``True``, then FSDP keeps the gradients + in the dtype used for gradient reduction, which can save memory if + using a custom optimizer that supports running in low precision. + (Default: ``False``) + cast_forward_inputs (bool): If ``True``, then this FSDP module casts + its forward args and kwargs to ``param_dtype``. This is to ensure + that parameter and input dtypes match for forward computation, as + required by many ops. This may need to be set to ``True`` when only + applying mixed precision to some but not all FSDP modules, in which + case a mixed-precision FSDP submodule needs to recast its inputs. + (Default: ``False``) + cast_root_forward_inputs (bool): If ``True``, then the root FSDP module + casts its forward args and kwargs to ``param_dtype``, overriding + the value of ``cast_forward_inputs``. For non-root FSDP modules, + this does not do anything. (Default: ``True``) + _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies + module classes to ignore for mixed precision when using an + ``auto_wrap_policy``: Modules of these classes will have FSDP + applied to them separately with mixed precision disabled (meaning + that the final FSDP construction would deviate from the specified + policy). If ``auto_wrap_policy`` is not specified, then this does + not do anything. This API is experimental and subject to change. + (Default: ``(_BatchNorm,)``) + + .. note:: This API is experimental and subject to change. + + .. note:: Only floating point tensors are cast to their specified dtypes. + + .. note:: In ``summon_full_params``, parameters are forced to full + precision, but buffers are not. + + .. note:: Layer norm and batch norm accumulate in ``float32`` even when + their inputs are in a low precision like ``float16`` or ``bfloat16``. + Disabling FSDP's mixed precision for those norm modules only means that + the affine parameters are kept in ``float32``. However, this incurs + separate all-gathers and reduce-scatters for those norm modules, which + may be inefficient, so if the workload permits, the user should prefer + to still apply mixed precision to those modules. + + .. note:: By default, if the user passes a model with any ``_BatchNorm`` + modules and specifies an ``auto_wrap_policy``, then the batch norm + modules will have FSDP applied to them separately with mixed precision + disabled. See the ``_module_classes_to_ignore`` argument. + + .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and + ``cast_forward_inputs=False`` by default. For the root FSDP instance, + its ``cast_root_forward_inputs`` takes precedence over its + ``cast_forward_inputs``. For non-root FSDP instances, their + ``cast_root_forward_inputs`` values are ignored. The default setting is + sufficient for the typical case where each FSDP instance has the same + ``MixedPrecision`` configuration and only needs to cast inputs to the + ``param_dtype`` at the beginning of the model's forward pass. + + .. note:: For nested FSDP instances with different ``MixedPrecision`` + configurations, we recommend setting individual ``cast_forward_inputs`` + values to configure casting inputs or not before each instance's + forward. In such a case, since the casts happen before each FSDP + instance's forward, a parent FSDP instance should have its non-FSDP + submodules run before its FSDP submodules to avoid the activation dtype + being changed due to a different ``MixedPrecision`` configuration. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + >>> model[1] = FSDP( + >>> model[1], + >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), + >>> ) + >>> model = FSDP( + >>> model, + >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), + >>> ) + + The above shows a working example. On the other hand, if ``model[1]`` + were replaced with ``model[0]``, meaning that the submodule using + different ``MixedPrecision`` ran its forward first, then ``model[1]`` + would incorrectly see ``float16`` activations instead of ``bfloat16`` + ones. + + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + buffer_dtype: Optional[torch.dtype] = None + keep_low_precision_grads: bool = False + cast_forward_inputs: bool = False + cast_root_forward_inputs: bool = True + _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,) + + +@dataclass +class CPUOffload: + """ + This configures CPU offloading. + + Attributes: + offload_params (bool): This specifies whether to offload parameters to + CPU when not involved in computation. If ``True``, then this + offloads gradients to CPU as well, meaning that the optimizer step + runs on CPU. + """ + + offload_params: bool = False + + +class StateDictType(Enum): + """ + This enum indicates that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + + .. note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all ``state_dict`` configuration + classes. Users should instantiate a child class (e.g. + ``FullStateDictConfig``) in order to configure settings for the + corresponding ``state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict + values to CPU, and if ``False``, then FSDP keeps them on GPU. + (Default: ``False``) + """ + + offload_to_cpu: bool = False + + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. We recommend enabling both + ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state + dicts to save GPU memory and CPU memory, respectively. This config class + is meant to be used via the :func:`state_dict_type` context manager as + follows: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP( + ... model, + ... device_id=torch.cuda.current_device(), + ... auto_wrap_policy=..., + ... sync_module_states=True, + ... ) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. + + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + """ + ``ShardedStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class OptimStateDictConfig: + """ + ``OptimStateDictConfig`` is the base class for all ``optim_state_dict`` + configuration classes. Users should instantiate a child class (e.g. + ``FullOptimStateDictConfig``) in order to configure settings for the + corresponding ``optim_state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's + tensor values to CPU, and if ``False``, then FSDP keeps them on the + original device (which is GPU unless parameter CPU offloading is + enabled). (Default: ``True``) + """ + + offload_to_cpu: bool = True + + +@dataclass +class FullOptimStateDictConfig(OptimStateDictConfig): + """ + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalOptimStateDictConfig(OptimStateDictConfig): + offload_to_cpu: bool = False + + +@dataclass +class ShardedOptimStateDictConfig(OptimStateDictConfig): + """ + ``ShardedOptimStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class StateDictSettings: + state_dict_type: StateDictType + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc5ef424e7052a41ddb986da07e1edb389bed27 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,2199 @@ +# mypy: ignore-errors + +import contextlib +import copy +import functools +import math +import traceback +import warnings +from collections.abc import Callable, Generator, Iterable, Iterator +from contextlib import contextmanager +from enum import auto, Enum +from typing import Any, Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_WRAPPED_MODULE, + ActivationWrapper, +) +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_param_to_fqns, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo +from torch.distributed.fsdp._init_utils import ( + _check_orig_params_flattened, + _init_buffer_state, + _init_core_state, + _init_device_handle, + _init_extension, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, + _init_process_group_state, + _init_runtime_state, + _init_state_dict_state, + HYBRID_SHARDING_STRATEGIES, + ProcessGroupType, +) +from torch.distributed.fsdp._runtime_utils import ( + _get_fsdp_root_states, + _is_fsdp_root, + _lazy_init, + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, + _unshard, + _wait_for_computation_stream, +) +from torch.distributed.fsdp._wrap_utils import _auto_wrap +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DeviceMesh +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParameter, FlatParamHandle +from ._optim_utils import ( + _flatten_optim_state_dict, + _get_param_id_to_param_from_optim_input, + _get_param_key_to_param, + _get_param_to_param_id_from_optim_input, + _get_param_to_param_key, + _optim_state_dict, + _rekey_sharded_optim_state_dict, + _set_optim_use_dtensor, +) +from ._state_dict_utils import _register_all_state_dict_hooks +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_flat_param, + _register_orig_params, + _unshard_params, + _unshard_params_for_summon, +) +from .wrap import CustomPolicy, ModuleWrapPolicy + + +__all__ = [ + "FullyShardedDataParallel", + "OptimStateKeyType", +] + + +FLAT_PARAM = "_flat_param" + + +class OptimStateKeyType(Enum): + """Represents the type of key in an optimizer state-dict.""" + + PARAM_NAME = auto() + PARAM_ID = auto() + + +class FullyShardedDataParallel(nn.Module, _FSDPState): + """A wrapper for sharding module parameters across data parallel workers. + + This is inspired by `Xu et al. `_ as + well as the ZeRO Stage 3 from `DeepSpeed `_. + FullyShardedDataParallel is commonly shortened to FSDP. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> torch.cuda.set_device(device_id) + >>> sharded_module = FSDP(my_module) + >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) + >>> loss = x.sum() + >>> loss.backward() + >>> optim.step() + + Using FSDP involves wrapping your module and then initializing your + optimizer after. This is required since FSDP changes the parameter + variables. + + When setting up FSDP, you need to consider the destination CUDA + device. If the device has an ID (``dev_id``), you have three options: + + * Place the module on that device + * Set the device using ``torch.cuda.set_device(dev_id)`` + * Pass ``dev_id`` into the ``device_id`` constructor argument. + + This ensures that the FSDP instance's compute device is the + destination device. For option 1 and 3, the FSDP initialization + always occurs on GPU. For option 2, the FSDP initialization + happens on module's current device, which may be a CPU. + + If you're using the ``sync_module_states=True`` flag, you need to + ensure that the module is on a GPU or use the ``device_id`` + argument to specify a CUDA device that FSDP will move the module + to in the FSDP constructor. This is necessary because + ``sync_module_states=True`` requires GPU communication. + + FSDP also takes care of moving input tensors to the forward method + to the GPU compute device, so you don't need to manually move them + from CPU. + + For ``use_orig_params=True``, + ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded + parameters, not the sharded parameters after forward, unlike + ``ShardingStrategy.FULL_SHARD``. If you want + to inspect the gradients, you can use the ``summon_full_params`` + method with ``with_grads=True``. + + With ``limit_all_gathers=True``, you may see a gap in the FSDP + pre-forward where the CPU thread is not issuing any kernels. This is + intentional and shows the rate limiter in effect. Synchronizing the CPU + thread in that way prevents over-allocating memory for subsequent + all-gathers, and it should not actually delay GPU kernel execution. + + FSDP replaces managed modules' parameters with ``torch.Tensor`` + views during forward and backward computation for autograd-related + reasons. If your module's forward relies on saved references to + the parameters instead of reacquiring the references each + iteration, then it will not see FSDP's newly created views, + and autograd will not work correctly. + + Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD`` + with the sharding process group being intra-node and the + replication process group being inter-node, setting + ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over + the replication process group for some cluster setups. + + **Limitations** + + There are several limitations to be aware of when using FSDP: + + * FSDP currently does not support gradient accumulation outside + ``no_sync()`` when using CPU offloading. This is because FSDP + uses the newly-reduced gradient instead of accumulating with any + existing gradient, which can lead to incorrect results. + + * FSDP does not support running the forward pass of a submodule + that is contained in an FSDP instance. This is because the + submodule's parameters will be sharded, but the submodule itself + is not an FSDP instance, so its forward pass will not all-gather + the full parameters appropriately. + + * FSDP does not work with double backwards due to the way it + registers backward hooks. + + * FSDP has some constraints when freezing parameters. + For ``use_orig_params=False``, each FSDP instance must manage + parameters that are all frozen or all non-frozen. For + ``use_orig_params=True``, FSDP supports mixing frozen and + non-frozen parameters, but it's recommended to avoid doing so to + prevent higher than expected gradient memory usage. + + * As of PyTorch 1.12, FSDP offers limited support for shared + parameters. If enhanced shared parameter support is needed for + your use case, please post in + `this issue `__. + + * You should avoid modifying the parameters between forward and + backward without using the ``summon_full_params`` context, as + the modifications may not persist. + + Args: + module (nn.Module): + This is the module to be wrapped with FSDP. + process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]): + This is the process group over which the model is sharded and thus + the one used for FSDP's all-gather and reduce-scatter collective + communications. If ``None``, then FSDP uses the default process + group. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of + process groups, representing the groups over which to shard and + replicate, respectively. If ``None``, then FSDP constructs process + groups for the user to shard intra-node and replicate inter-node. + (Default: ``None``) + sharding_strategy (Optional[ShardingStrategy]): + This configures the sharding strategy, which may trade off memory + saving and communication overhead. See :class:`ShardingStrategy` + for details. (Default: ``FULL_SHARD``) + cpu_offload (Optional[CPUOffload]): + This configures CPU offloading. If this is set to ``None``, then + no CPU offloading happens. See :class:`CPUOffload` for details. + (Default: ``None``) + auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]): + This specifies a policy to apply FSDP to submodules of ``module``, + which is needed for communication and computation overlap and thus + affects performance. If ``None``, then FSDP only applies to + ``module``, and users should manually apply FSDP to parent modules + themselves (proceeding bottom-up). For convenience, this accepts + ``ModuleWrapPolicy`` directly, which allows users to specify the + module classes to wrap (e.g. the transformer block). Otherwise, + this should be a callable that takes in three arguments + ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should have FSDP applied if + ``recurse=False`` or if the traversal should continue into the + module's subtree if ``recurse=True``. Users may add additional + arguments to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + applies FSDP to a module if the parameters in its subtree exceed + 100M numel. We recommend printing the model after applying FSDP + and adjusting as needed. + + Example:: + + >>> def custom_auto_wrap_policy( + >>> module: nn.Module, + >>> recurse: bool, + >>> nonwrapped_numel: int, + >>> # Additional custom arguments + >>> min_num_params: int = int(1e8), + >>> ) -> bool: + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) + + backward_prefetch (Optional[BackwardPrefetch]): + This configures explicit backward prefetching of all-gathers. If + ``None``, then FSDP does not backward prefetch, and there is no + communication and computation overlap in the backward pass. See + :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``) + mixed_precision (Optional[MixedPrecision]): + This configures native mixed precision for FSDP. If this is set to + ``None``, then no mixed precision is used. Otherwise, parameter, + buffer, and gradient reduction dtypes can be set. See + :class:`MixedPrecision` for details. (Default: ``None``) + ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose + own parameters and child modules' parameters and buffers are + ignored by this instance. None of the modules directly in + ``ignored_modules`` should be :class:`FullyShardedDataParallel` + instances, and any child modules that are already-constructed + :class:`FullyShardedDataParallel` instances will not be ignored if + they are nested under this instance. This argument may be used to + avoid sharding specific parameters at module granularity when using an + ``auto_wrap_policy`` or if parameters' sharding is not managed by + FSDP. (Default: ``None``) + param_init_fn (Optional[Callable[[nn.Module], None]]): + A ``Callable[torch.nn.Module] -> None`` that + specifies how modules that are currently on the meta device should + be initialized onto an actual device. As of v1.12, FSDP detects + modules with parameters or buffers on meta device via ``is_meta`` + and either applies ``param_init_fn`` if specified or calls + ``nn.Module.reset_parameters()`` otherwise. For both cases, the + implementation should *only* initialize the parameters/buffers of + the module, not those of its submodules. This is to avoid + re-initialization. In addition, FSDP also supports deferred + initialization via torchdistX's (https://github.com/pytorch/torchdistX) + ``deferred_init()`` API, where the deferred modules are initialized + by calling ``param_init_fn`` if specified or torchdistX's default + ``materialize_module()`` otherwise. If ``param_init_fn`` is + specified, then it is applied to all meta-device modules, meaning + that it should probably case on the module type. FSDP calls the + initialization function before parameter flattening and sharding. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> module = MyModule(device="meta") + >>> def my_init_fn(module: nn.Module): + >>> # E.g. initialize depending on the module type + >>> ... + >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) + >>> print(next(fsdp_model.parameters()).device) # current CUDA device + >>> # With torchdistX + >>> module = deferred_init.deferred_init(MyModule, device="cuda") + >>> # Will initialize via deferred_init.materialize_module(). + >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) + + device_id (Optional[Union[int, torch.device]]): An ``int`` or + ``torch.device`` giving the CUDA device on which FSDP + initialization takes place, including the module initialization + if needed and the parameter sharding. This should be specified to + improve initialization speed if ``module`` is on CPU. If the + default CUDA device was set (e.g. via ``torch.cuda.set_device``), + then the user may pass ``torch.cuda.current_device`` to this. + (Default: ``None``) + sync_module_states (bool): If ``True``, then each FSDP module will + broadcast module parameters and buffers from rank 0 to ensure that + they are replicated across ranks (adding communication overhead to + this constructor). This can help load ``state_dict`` checkpoints + via ``load_state_dict`` in a memory efficient way. See + :class:`FullStateDictConfig` for an example of this. (Default: + ``False``) + forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches + the next forward-pass all-gather before the current forward + computation. This is only useful for CPU-bound workloads, in which + case issuing the next all-gather earlier may improve overlap. This + should only be used for static-graph models since the prefetching + follows the first iteration's execution order. (Default: ``False``) + limit_all_gathers (bool): If ``True``, then FSDP explicitly + synchronizes the CPU thread to ensure GPU memory usage from only + *two* consecutive FSDP instances (the current instance running + computation and the next instance whose all-gather is prefetched). + If ``False``, then FSDP allows the CPU thread to issue all-gathers + without any extra synchronization. (Default: ``True``) We often + refer to this feature as the "rate limiter". This flag should only + be set to ``False`` for specific CPU-bound workloads with low + memory pressure in which case the CPU thread can aggressively issue + all kernels without concern for the GPU memory usage. + use_orig_params (bool): Setting this to ``True`` has FSDP use + ``module`` 's original parameters. FSDP exposes those original + parameters to the user via :meth:`nn.Module.named_parameters` + instead of FSDP's internal :class:`FlatParameter` s. This means + that the optimizer step runs on the original parameters, enabling + per-original-parameter hyperparameters. FSDP preserves the original + parameter variables and manipulates their data between unsharded + and sharded forms, where they are always views into the underlying + unsharded or sharded :class:`FlatParameter`, respectively. With the + current algorithm, the sharded form is always 1D, losing the + original tensor structure. An original parameter may have all, + some, or none of its data present for a given rank. In the none + case, its data will be like a size-0 empty tensor. Users should not + author programs relying on what data is present for a given + original parameter in its sharded form. ``True`` is required to + use ``torch.compile()``. Setting this to ``False`` exposes FSDP's + internal :class:`FlatParameter` s to the user via + :meth:`nn.Module.named_parameters`. (Default: ``False``) + ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]): + Ignored parameters or modules that will not be managed by this FSDP + instance, meaning that the parameters are not sharded and their + gradients are not reduced across ranks. This argument unifies with + the existing ``ignored_modules`` argument, and we may deprecate + ``ignored_modules`` soon. For backward compatibility, we keep both + ``ignored_states`` and `ignored_modules``, but FSDP only allows one + of them to be specified as not ``None``. + device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an alternative to + process_group. When device_mesh is passed, FSDP will use the underlying process + groups for all-gather and reduce-scatter collective communications. Therefore, + these two args need to be mutually exclusive. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh instead + of a tuple of process groups. For 2D FSDP + TP, users are required to pass in + device_mesh instead of process_group. For more DeviceMesh info, please visit: + https://pytorch.org/tutorials/recipes/distributed_device_mesh.html + """ + + def __init__( + self, + module: nn.Module, + process_group: ProcessGroupType = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[ + Union[Callable, ModuleWrapPolicy, CustomPolicy] + ] = None, + backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, + device_mesh: Optional[DeviceMesh] = None, + ): + torch._C._log_api_usage_once("torch.distributed.fsdp") + super().__init__() + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + warnings.warn( + "FSDP will not all-gather parameters for containers that do " + f"not implement forward: {module}", + stacklevel=2, + ) + _init_ignored_module_states(self, module, ignored_modules, ignored_states) + _init_device_handle(self, module, self._ignored_params, device_id) + + # Add module annotations for Dynamo support (see function for details) + _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params) + + # Initializes self.process_group, along with rank and world size. This will + # also set another attribute, _inter_node_pg, to control the process group + # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}. + # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up + # the same process group state as the root FSDP module. + self._device_mesh = device_mesh + _init_process_group_state( + self, + process_group, + sharding_strategy, + auto_wrap_policy, + device_mesh, + ) + if auto_wrap_policy is not None: + root_kwargs = { + "process_group": process_group, + "sharding_strategy": sharding_strategy, + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "mixed_precision": mixed_precision, + "param_init_fn": param_init_fn, + "device_id": device_id, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "limit_all_gathers": limit_all_gathers, + "use_orig_params": use_orig_params, + "ignored_states": self._ignored_params, + "device_mesh": device_mesh, + } + if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: + # Share root process groups with children to maintain + # the invariant that all FSDP modules will have the same + # process groups. + root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) + + _auto_wrap( + module, + auto_wrap_policy, + self._ignored_modules, + self._ignored_params, + root_kwargs, + FullyShardedDataParallel, + ) + + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + _init_core_state( + self, + sharding_strategy, + mixed_precision, + cpu_offload, + limit_all_gathers, + use_orig_params, + backward_prefetch_limit, + forward_prefetch_limit, + ) + _init_runtime_state(self) + _init_prefetching_state(self, backward_prefetch, forward_prefetch) + _init_buffer_state(self, module) + # extension needs to be set before `_init_param_handle_from_module()` + _init_extension(self, device_mesh) + _init_param_handle_from_module( + self, + module, + device_id, + param_init_fn, + sync_module_states, + ) + self._fsdp_wrapped_module = module + if not use_orig_params: + _check_orig_params_flattened(self, self._ignored_params) + _register_flat_param(self, self) + + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + _init_state_dict_state(self) + _register_all_state_dict_hooks(self) + self._zero_scalar = None + + @property + def module(self) -> nn.Module: + """Return the wrapped module.""" + # FSDP's `.module` must refer to the innermost wrapped module when + # composing with other module wrappers in order for state dict to work + if isinstance(self._fsdp_wrapped_module, ActivationWrapper): + return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE) + return self._fsdp_wrapped_module + + @property + def _has_params(self) -> bool: + """Returns whether this FSDP instance manages any parameters.""" + return hasattr(self, "_handle") and self._handle is not None + + @property + def _flat_param(self) -> Optional[FlatParameter]: + return self._handle.flat_param if self._handle else None + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._fsdp_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is an ``nn.Sequential``.""" + if hasattr(self, FSDP_WRAPPED_MODULE): + return self._fsdp_wrapped_module.__getitem__(key) # type: ignore[operator] + return super().__getitem__(key) + + def check_is_root(self) -> bool: + """Check if this instance is a root FSDP module.""" + return _is_fsdp_root(self, self) + + @staticmethod + def fsdp_modules( + module: nn.Module, + root_only: bool = False, + ) -> list["FullyShardedDataParallel"]: + """Return all nested FSDP instances. + + This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``. + + Args: + module (torch.nn.Module): Root module, which may or may not be an + ``FSDP`` module. + root_only (bool): Whether to return only FSDP root modules. + (Default: ``False``) + + Returns: + List[FullyShardedDataParallel]: FSDP modules that are nested in + the input ``module``. + """ + if root_only: + return _get_fsdp_root_states(module) + return traversal_utils._get_fsdp_states(module) + + def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). + + Compared to ``torch.nn.Module.apply``, this version additionally gathers + the full parameters before applying ``fn``. It should not be called from + within another ``summon_full_params`` context. + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + """ + uninitialized = self._is_root is None + self._assert_state(TrainingState.IDLE) + # Use `_unshard_params_for_summon()` with `recurse=False` instead of + # `_unshard_fsdp_state_params()` directly to perform lazy + # initialization, which is needed to initialize `FlatParameter` + # parameter attributes as required by the unshard logic + with _unshard_params_for_summon( + self, + self, + writeback=True, + rank0_only=False, + offload_to_cpu=False, + with_grads=False, + ): + ret = super().apply(fn) + + # Reset lazy init called in `_unshard_params_for_summon()` since + # `apply()` may have been called on FSDP instance that is not truly a + # root, in which case it will be incorrectly marked as one. + if uninitialized and self._is_root: + for module in traversal_utils._get_fsdp_states(self): + module._reset_lazy_init() + + return ret + + def _mixed_precision_enabled_for_buffers(self) -> bool: + """Return whether the user explicitly enabled buffer mixed precision. + + NOTE: Unlike parameters and gradient reduction, buffer mixed precision + is applied at the FSDP instance level, not the ``FlatParameter`` level, + which may be different for the composable code path. + """ + return self.mixed_precision.buffer_dtype is not None + + def _low_precision_hook_enabled(self) -> bool: + """Whether a low precision hook is registered or not.""" + return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS + + def _reset_lazy_init(self) -> None: + """Reset instance so :func:`_lazy_init` will run on the next forward.""" + self._is_root: Optional[bool] = None + + @staticmethod + def set_state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> StateDictSettings: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + Also takes (optional) configuration for the model's and optimizer's state dict. + The target module does not have to be a FSDP module. If the target + module is a FSDP module, its ``state_dict_type`` will also be changed. + + .. note:: This API should be called for only the top-level (root) + module. + + .. note:: This API enables users to transparently use the conventional + ``state_dict`` API to take model checkpoints in cases where the + root FSDP module is wrapped by another ``nn.Module``. For example, + the following will ensure ``state_dict`` is called on all non-FSDP + instances, while dispatching into `sharded_state_dict` implementation + for FSDP: + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), + >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), + >>> ) + >>> param_state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the configuration for the + target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration + for the optimizer state dict. + + Returns: + A StateDictSettings that include the previous state_dict type and + configuration for the module. + """ + warnings.warn( + "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " + "deprecated. Please use APIs, get_state_dict() and set_state_dict(), " + "which can support different parallelisms, FSDP1, FSDP2, DDP. " + "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" + "#torch.distributed.checkpoint.state_dict.get_state_dict ." + "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", + FutureWarning, + stacklevel=2, + ) + _state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, + } + _optim_state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, + } + + # Use the default config if a state_dict config is not set. + state_dict_config_type = _state_dict_type_to_config[state_dict_type] + optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type] + if state_dict_config is None: + state_dict_config = state_dict_config_type() + if optim_state_dict_config is None: + optim_state_dict_config = optim_state_dict_config_type() + if state_dict_config_type is not type(state_dict_config): + raise RuntimeError( + f"Expected state_dict_config of type {state_dict_config_type} " + f"but got {type(state_dict_config)}" + ) + if optim_state_dict_config_type is not type(optim_state_dict_config): + raise RuntimeError( + f"Expected optim_state_dict_config of type {optim_state_dict_config_type} " + f"but got {type(optim_state_dict_config)}" + ) + + # Set the state_dict type and configurations. + prev_state_dict_type = None + prev_state_dict_config = None + prev_optim_state_dict_config = None + for submodule in traversal_utils._get_fsdp_states(module): + if prev_state_dict_type is None: + prev_state_dict_type = submodule._state_dict_type + else: + if prev_state_dict_type != submodule._state_dict_type: + raise AssertionError( + "All FSDP modules should have the same state_dict_type." + ) + if prev_state_dict_config is None: + prev_state_dict_config = submodule._state_dict_config + else: + if not isinstance( + submodule._state_dict_config, type(prev_state_dict_config) + ): + raise AssertionError( + "All FSDP modules must have the same type of state_dict_config." + ) + if prev_optim_state_dict_config is None: + prev_optim_state_dict_config = submodule._optim_state_dict_config + else: + if not isinstance( + submodule._optim_state_dict_config, + type(prev_optim_state_dict_config), + ): + raise AssertionError( + "All FSDP modules must have the same type of optim_state_dict_config." + ) + + submodule._state_dict_type = state_dict_type + submodule._state_dict_config = state_dict_config + submodule._optim_state_dict_config = optim_state_dict_config + + return StateDictSettings( + prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config + ) + + @staticmethod + def get_state_dict_type(module: nn.Module) -> StateDictSettings: + """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``. + + The target module does not have to be an FSDP module. + + Returns: + A ``StateDictSettings`` containing the state_dict_type and + state_dict / optim_state_dict configs that are currently set. + + Raises: + ``AssertionError`` if the ``StateDictSettings`` for different + FSDP submodules differ. + """ + state_dict_settings: Optional[StateDictSettings] = None + for submodule in FullyShardedDataParallel.fsdp_modules(module): + if state_dict_settings is None: + state_dict_settings = StateDictSettings( + state_dict_type=submodule._state_dict_type, + state_dict_config=submodule._state_dict_config, + optim_state_dict_config=submodule._optim_state_dict_config, + ) + _set_optim_use_dtensor(submodule, state_dict_settings) + else: + submodule_settings = StateDictSettings( + submodule._state_dict_type, + submodule._state_dict_config, + submodule._optim_state_dict_config, + ) + if state_dict_settings != submodule_settings: + raise AssertionError( + "All FSDP modules must have the same state dict settings." + f"Got {submodule_settings} and {state_dict_settings}." + ) + _set_optim_use_dtensor(submodule, submodule_settings) + return state_dict_settings + + @staticmethod + @contextlib.contextmanager + def state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> Generator: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of + :meth:`set_state_dict_type` for the detail. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> with FSDP.state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> ): + >>> checkpoint = model.state_dict() + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the model ``state_dict`` + configuration for the target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer + ``state_dict`` configuration for the target ``state_dict_type``. + """ + prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ) + yield + FullyShardedDataParallel.set_state_dict_type( + module, + prev_state_dict_settings.state_dict_type, + prev_state_dict_settings.state_dict_config, + prev_state_dict_settings.optim_state_dict_config, + ) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.""" + handle = self._handle + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel.forward" + ): + args, kwargs = _root_pre_forward(self, self, args, kwargs) + unused = None + args, kwargs = _pre_forward( + self, + handle, + _pre_forward_unshard, + self._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == self.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{self.compute_device} but got {handle.flat_param.device}", + ) + output = self._fsdp_wrapped_module(*args, **kwargs) + return _post_forward( + self, handle, _post_forward_reshard, self, unused, output + ) + + @staticmethod + @contextlib.contextmanager + def summon_full_params( + module: nn.Module, + recurse: bool = True, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, + ) -> Generator: + r"""Expose full params for FSDP instances with this context manager. + + Can be useful *after* forward/backward for a model to get + the params for additional processing or checking. It can take a non-FSDP + module and will summon full params for all contained FSDP modules as + well as their children, depending on the ``recurse`` argument. + + .. note:: This can be used on inner FSDPs. + .. note:: This can *not* be used within a forward or backward pass. Nor + can forward and backward be started from within this context. + .. note:: Parameters will revert to their local shards after the context + manager exits, storage behavior is the same as forward. + .. note:: The full parameters can be modified, but only the portion + corresponding to the local param shard will persist after the + context manager exits (unless ``writeback=False``, in which case + changes will be discarded). In the case where FSDP does not shard + the parameters, currently only when ``world_size == 1``, or ``NO_SHARD`` + config, the modification is persisted regardless of ``writeback``. + .. note:: This method works on modules which are not FSDP themselves but + may contain multiple independent FSDP units. In that case, the given + arguments will apply to all contained FSDP units. + + .. warning:: Note that ``rank0_only=True`` in conjunction with + ``writeback=True`` is not currently supported and will raise an + error. This is because model parameter shapes would be different + across ranks within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + + .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will + result in full parameters being redundantly copied to CPU memory for + GPUs that reside on the same machine, which may incur the risk of + CPU OOM. It is recommended to use ``offload_to_cpu`` with + ``rank0_only=True``. + + Args: + recurse (bool, Optional): recursively summon all params for nested + FSDP instances (default: True). + writeback (bool, Optional): if ``False``, modifications to params are + discarded after the context manager exits; + disabling this can be slightly more efficient (default: True) + rank0_only (bool, Optional): if ``True``, full parameters are + materialized on only global rank 0. This means that within the + context, only rank 0 will have full parameters and the other + ranks will have sharded parameters. Note that setting + ``rank0_only=True`` with ``writeback=True`` is not supported, + as model parameter shapes will be different across ranks + within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + offload_to_cpu (bool, Optional): If ``True``, full parameters are + offloaded to CPU. Note that this offloading currently only + occurs if the parameter is sharded (which is only not the case + for world_size = 1 or ``NO_SHARD`` config). It is recommended + to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid + redundant copies of model parameters being offloaded to the same CPU memory. + with_grads (bool, Optional): If ``True``, gradients are also + unsharded with the parameters. Currently, this is only + supported when passing ``use_orig_params=True`` to the FSDP + constructor and ``offload_to_cpu=False`` to this method. + (Default: ``False``) + """ + with _unshard_params( + module, recurse, writeback, rank0_only, offload_to_cpu, with_grads + ): + yield + + @contextlib.contextmanager + def _deregister_orig_params_ctx(self): + """Deregister the original parameters and expose the :class:`FlatParameter`. + + If a :class:`FlatParameter` is sharded, then + this refreshes the sharded views before exiting. This method should + only be called when using the original parameters. + """ + _p_assert( + self._use_orig_params, + "`_deregister_orig_params_ctx()` should only be called when " + "`_use_orig_params=True`", + ) + for fsdp_module in traversal_utils._get_fsdp_states(self): + _deregister_orig_params(fsdp_module, fsdp_module) + try: + yield + finally: + for fsdp_module in traversal_utils._get_fsdp_states(self): + _register_orig_params(fsdp_module, fsdp_module) + + def _apply(self, *args, **kwargs): + """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``.""" + # When using the original parameters: Since (1) the `FlatParameter`s + # own the storage and (2) `_apply()` is the subroutine underlying the + # most common storage-changing ops like `to()` and `cuda()`, we + # override `_apply()` to have the storage change directly performed on + # the `FlatParameter`s instead of applying to the original parameters + # and then writing back to the `FlatParameter`s. + context = ( + self._deregister_orig_params_ctx() + if self._use_orig_params + else contextlib.nullcontext() + ) + with context: + return super()._apply(*args, **kwargs) + + def named_buffers( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself. + + Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for buffer_name, buffer in super().named_buffers(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + buffer_name = buffer_name.replace(FSDP_PREFIX, "") + yield (buffer_name, buffer) + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.nn.Parameter]]: + """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself. + + Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for param_name, param in super().named_parameters(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + param_name = param_name.replace(FSDP_PREFIX, "") + yield (param_name, param) + + def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None: + """Assert we are in the given state.""" + # Since assert can be turned off and this error checking + # is really important, we use explicit error checking + # and raise a ValueError if needed. + if isinstance(state, TrainingState): + state = [state] + if self.training_state not in state: + msg = ( + f"expected to be in states {state} but current state " + f"is {self.training_state}" + ) + # In case we are failing in the context of autograd hook, asserting + # may not generate useful msg. So, let's print it to be sure. + if self.rank == 0: + print(f"Asserting FSDP instance is: {self}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + @contextmanager + def no_sync(self) -> Generator: + """Disable gradient synchronizations across FSDP instances. + + Within this context, gradients will be accumulated in module + variables, which will later be synchronized in the first + forward-backward pass after exiting the context. This should only be + used on the root FSDP instance and will recursively apply to all + children FSDP instances. + + .. note:: This likely results in higher memory usage because FSDP will + accumulate the full model gradients (instead of gradient shards) + until the eventual sync. + + .. note:: When used with CPU offloading, the gradients will not be + offloaded to CPU when inside the context manager. Instead, they + will only be offloaded right after the eventual sync. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module." + ) + self._assert_state(TrainingState.IDLE) + old_flags = [] + for m in self.modules(): + if isinstance(m, FullyShardedDataParallel): + old_flags.append((m, m._sync_gradients)) + m._sync_gradients = False + try: + yield + finally: + for m, old_flag in old_flags: + if m._sync_gradients: + raise AssertionError( + "`_sync_gradients` was incorrectly set to " + "`True` while in the `no_sync()` context manager" + ) + m._sync_gradients = old_flag + + @torch.no_grad() + def clip_grad_norm_( + self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 + ) -> torch.Tensor: + """Clip the gradient norm of all parameters. + + The norm is computed over all parameters' gradients as viewed as a single vector, and the + gradients are modified in-place. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` + for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + If every FSDP instance uses ``NO_SHARD``, meaning that no + gradients are sharded across ranks, then you may directly use + :func:`torch.nn.utils.clip_grad_norm_`. + + If at least some FSDP instance uses a sharded strategy (i.e. + one other than ``NO_SHARD``), then you should use this method + instead of :func:`torch.nn.utils.clip_grad_norm_` since this method + handles the fact that gradients are sharded across ranks. + + The total norm returned will have the "largest" dtype across + all parameters/gradients as defined by PyTorch's type promotion + semantics. For example, if *all* parameters/gradients use a low + precision dtype, then the returned norm's dtype will be that low + precision dtype, but if there exists at least one parameter/ + gradient using FP32, then the returned norm's dtype will be FP32. + + .. warning:: This needs to be called on all ranks since it uses + collective communications. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`clip_grad_norm_()` should only be called on the root FSDP instance" + ) + if self._zero_scalar is None: + self._zero_scalar = torch.tensor(0.0, device=self.compute_device) + self._assert_state(TrainingState.IDLE) + # If every FSDP instance uses `NO_SHARD`, then we can directly use + # the normal `nn.utils` one targeting local gradients + all_no_shard = all( + not handle.uses_sharded_strategy for handle in self._all_handles + ) + if all_no_shard: + return torch.nn.utils.clip_grad_norm_( + self.parameters(), max_norm, norm_type + ) + # Otherwise, there exists some FSDP instance using a sharded strategy, + # where sharded and non-sharded parameters must be handled separately + max_norm = float(max_norm) + norm_type = float(norm_type) + sharded_params_set = set() + nonsharded_params_set = set() # `NO_SHARD` or not FSDP-managed + # Make sure to compute the local norm using lists for deterministic + # iteration order and hence deterministic total norm computation + sharded_params = [] + nonsharded_params = [] + grads: list[torch.Tensor] = [] + for handle in self._all_handles: + if handle.uses_sharded_strategy: + target_set = sharded_params_set + target_list = sharded_params + else: + target_set = nonsharded_params_set + target_list = nonsharded_params + if handle._use_orig_params: + for param in handle.flat_param._params: + if param not in target_set: + target_set.add(param) + target_list.append(param) + if param.grad is not None: + grads.append(param.grad) + else: + if handle.flat_param not in target_set: + target_set.add(handle.flat_param) + target_list.append(handle.flat_param) + if handle.flat_param.grad is not None: + grads.append(handle.flat_param.grad) + for param in self.parameters(): + not_fsdp_managed = ( + param not in sharded_params_set and param not in nonsharded_params_set + ) + if not_fsdp_managed: + nonsharded_params_set.add(param) + nonsharded_params.append(param) + if param.grad is not None: + grads.append(param.grad) + # Compute local norms (forced to be in FP32) + local_sharded_norm = _get_grad_norm( + sharded_params, norm_type, self._zero_scalar, self.compute_device + ) + local_nonsharded_norm = ( + _get_grad_norm( + nonsharded_params, norm_type, self._zero_scalar, self.compute_device + ) + if nonsharded_params + else None + ) + # Reconstruct the total gradient norm depending on the norm type + if norm_type == math.inf: + total_norm = ( + torch.maximum(local_sharded_norm, local_nonsharded_norm) + if local_nonsharded_norm is not None + else local_sharded_norm + ) + dist.all_reduce( + total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group + ) + else: + total_norm = local_sharded_norm**norm_type + dist.all_reduce(total_norm, group=self.process_group) + # All-reducing the local non-sharded norm would count it an extra + # world-size-many times + if local_nonsharded_norm is not None: + total_norm += local_nonsharded_norm**norm_type + total_norm = total_norm ** (1.0 / norm_type) + if self.cpu_offload.offload_params: + total_norm = total_norm.cpu() + + clip_coef = max_norm / (total_norm + 1e-6) + # Multiplying by the clamped coefficient is meaningless when it is + # equal to 1, but it avoids the host-device sync that would result from + # `if clip_coef < 1` + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for grad in grads: + grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype)) + # Use the "largest" dtype by type promotion semantics to use the same + # dtype as if we did not force local norm computation to be in FP32 + if len(grads) == 0: + # If this rank has no gradients, then we must default to FP32 + # unless we use additional communication, which we prefer to avoid + # since `clip_grad_norm_()` is called in the training loop + warnings.warn( + f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no " + "gradients -- returning the total norm in the default dtype " + f"{total_norm.dtype}", + stacklevel=2, + ) # warn since this is generally unexpected + return total_norm + total_norm_dtype = functools.reduce( + torch.promote_types, + [grad.dtype for grad in grads], + ) + return total_norm.to(total_norm_dtype) + + @staticmethod + def _warn_optim_input(optim_input, *, stacklevel: int = 1): + if optim_input is not None: + warnings.warn( + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " + "You may remove it from your code without changing its functionality.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _is_using_optim_input(optim_input, optim) -> bool: + if optim_input is None and optim is None: + # Use the default behavior of `optim_input`` + return True + if optim_input is not None: + # Use the `optim_input` code path + return True + # Use the `optim` code path + return False + + @staticmethod + def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1): + warnings.warn( + f"``FullyShardedDataParallel.{curr}``is being deprecated and is " + f"replaced by ``FullyShardedDataParallel.{new}``. " + f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _optim_state_dict_impl( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + full_state_dict: bool = True, + group: Optional[dist.ProcessGroup] = None, + cpu_offload: bool = True, + *, + _stacklevel: int = 1, + ) -> dict[str, Any]: + """Transform the state-dict of an optimizer corresponding to a sharded model. + + This is the internal API that is used by all the optim_state_dict implementations. + Given model, optim, the original optim_state_dict, this API removes the + FSDP internal information and internal sharding from the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input( + optim_input, stacklevel=_stacklevel + 1 + ) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + if optim_input is not None or rank0_only: + raise AssertionError( + f"Expected optim_input to be None and rank0_only to be False, " + f"got optim_input={optim_input}, rank0_only={rank0_only}" + ) + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + if not all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ): + raise AssertionError( + "Not all FSDP modules have the same _use_orig_params value" + ) + + return _optim_state_dict( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=optim_input, + rank0_only=rank0_only, + shard_state=not full_state_dict, + group=group, + using_optim_input=using_optim_input, + use_orig_params=use_orig_params, + cpu_offload=cpu_offload, + ) + + @staticmethod + def _optim_state_dict_to_load_impl( + optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + This is the internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + if optim_input is not None or rank0_only: + raise AssertionError( + f"Expected optim_input to be None and rank0_only to be False, " + f"got optim_input={optim_input}, rank0_only={rank0_only}" + ) + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + if not all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ): + raise AssertionError( + "Not all FSDP modules have the same _use_orig_params value" + ) + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + @staticmethod + def full_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the full optimizer state-dict. + + Consolidates the full optimizer state on rank 0 and returns it + as a :class:`dict` following the convention of + :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"`` + and ``"param_groups"``. The flattened parameters in ``FSDP`` modules + contained in ``model`` are mapped back to their unflattened parameters. + + This needs to be called on all ranks since it uses + collective communications. However, if ``rank0_only=True``, then + the state dict is only populated on rank 0, and all other ranks + return an empty :class:`dict`. + + Unlike ``torch.optim.Optimizer.state_dict()``, this method + uses full parameter names as keys instead of parameter IDs. + + Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors + contained in the optimizer state dict are not cloned, so there may + be aliasing surprises. For best practices, consider saving the + returned optimizer state dict immediately, e.g. using + ``torch.save()``. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer ``optim`` representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + group (dist.ProcessGroup): Model's process group or ``None`` if using + the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``, + then nonzero ranks return an empty :class:`dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "full_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=optim_input, + rank0_only=rank0_only, + group=group, + full_state_dict=True, + _stacklevel=2, + ) + + @staticmethod + def sharded_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the optimizer state-dict in its sharded form. + + The API is similar to :meth:`full_optim_state_dict` but this API chunks + all non-zero-dimension states to :class:`ShardedTensor` to save memory. + This API should only be used when the model ``state_dict`` is derived + with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``. + + For the detailed usage, refer to :meth:`full_optim_state_dict`. + + .. warning:: The returned state dict contains ``ShardedTensor`` and + cannot be directly used by the regular ``optim.load_state_dict``. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "sharded_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=None, + rank0_only=False, + full_state_dict=False, + group=group, + _stacklevel=2, + ) + + @staticmethod + def shard_full_optim_state_dict( + full_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Shard a full optimizer state-dict. + + Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened + parameters and restricts to only this rank's part of the optimizer state. + The first argument should be the return value of :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) + >>> torch.save(full_osd, PATH) + >>> # Define new model with possibly different world size + >>> new_model, new_optim = ... + >>> full_osd = torch.load(PATH) + >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + full non-sharded optimizer state. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "shard_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + is_named_optimizer=False, + ) + + @staticmethod + def flatten_sharded_optim_state_dict( + sharded_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim: torch.optim.Optimizer, + ) -> dict[str, Any]: + """Flatten a sharded optimizer state-dict. + + The API is similar to :meth:`shard_full_optim_state_dict`. The only + difference is that the input ``sharded_optim_state_dict`` should be + returned from :meth:`sharded_optim_state_dict`. Therefore, there will + be all-gather calls on each rank to gather ``ShardedTensor`` s. + + Args: + sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + sharded optimizer state. + model (torch.nn.Module): + Refer to :meth:`shard_full_optim_state_dict`. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + + Returns: + Refer to :meth:`shard_full_optim_state_dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "flatten_sharded_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=sharded_optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=False, + is_named_optimizer=False, + ) + + @staticmethod + def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, + ) -> dict[str, Any]: + """Scatter the full optimizer state dict from rank 0 to all other ranks. + + Returns the sharded optimizer state dict on each rank. + The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... + >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "scatter_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) + + @staticmethod + def rekey_optim_state_dict( + optim_state_dict: dict[str, Any], + optim_state_key_type: OptimStateKeyType, + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``. + + This can be used to achieve compatibility between optimizer state dicts from models with FSDP + instances and ones without. + + To re-key an FSDP full optimizer state dict (i.e. from + :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to + a non-wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> wrapped_model, wrapped_optim = ... + >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) + >>> nonwrapped_model, nonwrapped_optim = ... + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) + >>> nonwrapped_optim.load_state_dict(rekeyed_osd) + + To re-key a normal optimizer state dict from a non-wrapped model to be + loadable to a wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> nonwrapped_model, nonwrapped_optim = ... + >>> osd = nonwrapped_optim.state_dict() + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) + >>> wrapped_model, wrapped_optim = ... + >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) + >>> wrapped_optim.load_state_dict(sharded_osd) + + Returns: + Dict[str, Any]: The optimizer state dict re-keyed using the + parameter keys specified by ``optim_state_key_type``. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + if optim_state_key_type not in ( + OptimStateKeyType.PARAM_NAME, + OptimStateKeyType.PARAM_ID, + ): + raise AssertionError( + f"Expected optim_state_key_type to be PARAM_NAME or PARAM_ID, got {optim_state_key_type}" + ) + osd = optim_state_dict # alias + # Validate that the existing parameter keys are uniformly typed + uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]] + uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]] + if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or ( + any(uses_param_id_mask) and not all(uses_param_id_mask) + ): + error_msg = f"Invalid parameter keys: {osd['state'].keys()}" + raise ValueError(error_msg) + # Return directly if the existing key type matches the target key type + if ( + optim_state_key_type == OptimStateKeyType.PARAM_NAME + and all(uses_param_name_mask) + ) or ( + optim_state_key_type == OptimStateKeyType.PARAM_ID + and all(uses_param_id_mask) + ): + return osd + # Otherwise, actually perform the re-keying + new_osd = {} + if optim_state_key_type == OptimStateKeyType.PARAM_NAME: # ID -> name + param_id_to_param = ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param(optim) + ) + param_to_param_name = _get_param_to_fqn(model) + param_id_to_param_name: list[str] = [ + param_to_param_name[param] for param in param_id_to_param.values() + ] + new_osd["state"] = { + param_id_to_param_name[param_id]: param_state + for param_id, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_id_to_param_name[param_id] + for param_id in param_group["params"] + ] + ) + return new_osd + elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID + param_name_to_param = _get_fqn_to_param(model) + param_to_param_id = ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key(optim) + ) + # Because not all model parameters may be passed as the optimizer + # input, we may need to drop some parameters from this mapping + param_name_to_param_id = { + param_name: param_to_param_id[param] + for param_name, param in param_name_to_param.items() + if param in param_to_param_id + } + new_osd["state"] = { + param_name_to_param_id[param_name]: param_state + for param_name, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_name_to_param_id[param_name] + for param_name in param_group["params"] + ] + ) + return new_osd + return new_osd # should never reach here + + @staticmethod + def optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: Optional[dict[str, Any]] = None, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Transform the state-dict of an optimizer corresponding to a sharded model. + + The given state-dict can be transformed to one of three types: + 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict. + + For full optimizer state_dict, all states are unflattened and not sharded. + Rank0 only and CPU only can be specified via :meth:`state_dict_type` to + avoid OOM. + + For sharded optimizer state_dict, all states are unflattened but sharded. + CPU only can be specified via :meth:`state_dict_type` to further save + memory. + + For local state_dict, no transformation will be performed. But a state + will be converted from nn.Tensor to ShardedTensor to represent its sharding + nature (this is not supported yet). + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): the target optimizer state_dict to + transform. If the value is None, optim.state_dict() will be used. ( + Default: ``None``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model``. The sharding of the optimizer state is based on + ``state_dict_type``. + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + if optim_state_dict is None: + optim_state_dict = optim.state_dict() + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=None, + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + full_state_dict=state_dict_settings.state_dict_type + == StateDictType.FULL_STATE_DICT, + group=group, + cpu_offload=getattr( + state_dict_settings.optim_state_dict_config, "offload_to_cpu", True + ), + _stacklevel=2, + ) + + @staticmethod + def optim_state_dict_to_load( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + is_named_optimizer: bool = False, + load_directly: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + Given a ``optim_state_dict`` that is transformed through + :meth:`optim_state_dict`, it gets converted to the flattened optimizer + state_dict that can be loaded to ``optim`` which is the optimizer for + ``model``. ``model`` must be sharded by FullyShardedDataParallel. + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> original_osd = optim.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict( + >>> model, + >>> optim, + >>> optim_state_dict=original_osd + >>> ) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): The optimizer states to be loaded. + is_named_optimizer (bool): Is this optimizer a NamedOptimizer or + KeyedOptimizer. Only set to True if ``optim`` is TorchRec's + KeyedOptimizer or torch.distributed's NamedOptimizer. + load_directly (bool): If this is set to True, this API will also + call optim.load_state_dict(result) before returning the result. + Otherwise, users are responsible to call ``optim.load_state_dict()`` + (Default: ``False``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + result = FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=( + state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT + ), + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + is_named_optimizer=is_named_optimizer, + group=group, + ) + if load_directly: + optim.load_state_dict(result) + return result + + def register_comm_hook(self, state: object, hook: callable): + """Register a communication hook. + + This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates + gradients across multiple workers. + This hook can be used to implement several algorithms like + `GossipGrad `_ and gradient compression + which involve different communication strategies for + parameter syncs while training with :class:`FullyShardedDataParallel`. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in `GossipGrad `_, etc. + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. + + """ + if not self.check_is_root(): + raise AssertionError( + "register_comm_hook can only be called on a root instance." + ) + for fsdp_state in traversal_utils._get_fsdp_states(self): + if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + raise AssertionError( + f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}" + ) + if fsdp_state._comm_hook is not None: + raise AssertionError("A communication hook is already registered") + if not callable(hook): + raise ValueError( + f"The communication hook must be callable but got {hook}" + ) + fsdp_state._comm_hook = hook + fsdp_state._comm_hook_state = state + + def _unshard(self, async_op: bool = False): + class UnshardHandle: + def __init__( + self, + flat_param_handle: Optional[FlatParamHandle], + unshard_event: torch.Event, + ): + self._flat_param_handle = flat_param_handle + self._unshard_event = unshard_event + + def wait(self): + if self._flat_param_handle is not None: + current_stream = ( + self._flat_param_handle._device_handle.current_stream() + ) + current_stream.wait_event(self._unshard_event) + self._flat_param_handle = None + + if self._handle: + with self._use_training_state( + TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD + ): + _unshard( + self, self._handle, self._unshard_stream, self._pre_unshard_stream + ) + self._unshard_event = self._unshard_stream.record_event() + self._handle._prefetched = True + unshard_handle = UnshardHandle(self._handle, self._unshard_stream) + if async_op: + return unshard_handle + unshard_handle.wait() + return None + + def _wait_unshard_streams_on_current_stream(self): + _wait_for_computation_stream( + self._device_handle.current_stream(), + self._unshard_stream, + self._pre_unshard_stream, + ) + + @contextlib.contextmanager + def _use_training_state( + self, training_state: TrainingState, handle_training_state: HandleTrainingState + ): + prev_training_state = self.training_state + self.training_state = training_state + if self._handle: + prev_handle_training_state = self._handle._training_state + self._handle._training_state = handle_training_state + try: + yield + finally: + self.training_state = prev_training_state + if self._handle: + self._handle._training_state = prev_handle_training_state + + +def _get_grad_norm( + params: Iterable[nn.Parameter], + norm_type: float, + zero: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """ + Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector. + + The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream + use of this return value is a reduction across ranks. + """ + params_with_grad = [param for param in params if param.grad is not None] + if len(params_with_grad) == 0: + # Reuse a tensor for zero to avoid a GPU sync + return zero + grads = [param.grad for param in params_with_grad] + grad_dtypes = {grad.dtype for grad in grads} + if len(grad_dtypes) != 1: + raise ValueError( + f"Requires uniform dtype across all gradients but got {grad_dtypes}" + ) + # Compute the gradient norm in FP32, where we treat the gradients as a + # single vector + grad_norm = torch.linalg.vector_norm( + torch.stack( + [ + torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32) + for grad in grads + ], + ), + norm_type, + dtype=torch.float32, + ) + return grad_norm.to(device=device) + + +def _get_param_to_fqn( + model: torch.nn.Module, +) -> dict[torch.nn.Parameter, str]: + """ + Construct a mapping from parameters to their parameter names. + + The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which + means that none of the parameters should be ``FlatParameter`` s. As a + result, compared to :meth:`_get_param_to_fqns`, the mapped + values may be flattened from singleton :class:`list` s to the contained + names themselves. + + Args: + model (torch.nn.Module): Root module, which should not contain any + :class:`FullyShardedDataParallel` instances. + """ + param_to_param_names = _get_param_to_fqns(model) + for param_names in param_to_param_names.values(): + if len(param_names) == 0: + raise AssertionError( + "`_get_param_to_fqns()` should not construct empty lists" + ) + if len(param_names) > 1: + raise RuntimeError( + "Each parameter should only map to one parameter name but got " + f"{len(param_names)}: {param_names}" + ) + param_to_param_name = { + param: param_names[0] for param, param_names in param_to_param_names.items() + } + return param_to_param_name + + +def _get_fqn_to_param( + model: torch.nn.Module, +) -> dict[str, torch.nn.Parameter]: + """Construct the inverse mapping of :meth:`_get_param_to_fqn`.""" + param_to_param_name = _get_param_to_fqn(model) + return dict(zip(param_to_param_name.values(), param_to_param_name.keys())) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..3986d733328c80f12e6eed138386a9e8aafe6a3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py @@ -0,0 +1,377 @@ +# mypy: allow-untyped-defs +import logging +from collections import abc, defaultdict +from collections.abc import Iterable +from typing import Any, Optional, overload, Union + +import torch +import torch.distributed as dist +from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState +from torch.distributed.distributed_c10d import ProcessGroup + + +logger = logging.getLogger(__name__) + + +def _refresh_per_optimizer_state() -> dict[str, Any]: + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +def _is_supported_device(tensor: torch.Tensor) -> bool: + return tensor.is_cuda or tensor.device.type in ( + "xla", + "cpu", + "hpu", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ) + + +class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator): + """ + Lazily serves tensor to request device. This class extends + _MultiDeviceReplicator to allow support for "cpu" as a device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + if not _is_supported_device(master_tensor): + raise AssertionError( + f"Expected supported device, got {master_tensor.device}" + ) + self.master = master_tensor + self._per_device_tensors: dict[torch.device, torch.Tensor] = {} + + +class ShardedGradScaler(GradScaler): + """ + ShardedGradScaler helps perform gradient scaling in a shard aware manner. It extends + functionality from GradScaler: + * Supports Pytorch DDP and FSDP implementations + * Support CPU offloaded tensors (as used in fully sharded data parallel[FSDP]) + * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns + * Sync inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across + nodes + + Example:: + + # Creates a ShardedGradScaler once at the beginning of training. + scaler = ShardedGradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See :class:`GradScaler` for explanation of scaling/unscaling and more use cases. + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD): + process group for sharding + """ + + def __init__( + self, + device: str = "cuda", + init_scale: float = 2.0**16, + backoff_factor: float = 0.5, + growth_factor: float = 2.0, + growth_interval: int = 2000, + enabled: bool = True, + process_group: Optional[ProcessGroup] = dist.group.WORLD, + ) -> None: + super().__init__( + device, + init_scale=init_scale, + backoff_factor=backoff_factor, + growth_factor=growth_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + if self._enabled: + self.process_group = process_group + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + @overload + def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... + + @overload + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... + + @overload + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ... + + @overload + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... + + def scale( + self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]] + ) -> Union[torch.Tensor, Iterable[torch.Tensor]]: + if not self._enabled: + return outputs + + if isinstance(outputs, torch.Tensor): + if not _is_supported_device(outputs): + raise AssertionError(f"Expected supported device, got {outputs.device}") + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + scaled_output = outputs * self._scale.to( + device=outputs.device, non_blocking=True + ) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_output.type(outputs.dtype) + + stash: list[_GeneralMultiDeviceReplicator] = [] + + def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): + if isinstance(val, torch.Tensor): + if not _is_supported_device(val): + raise AssertionError(f"Expected supported device, got {val.device}") + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + if self._scale is None: + raise AssertionError( + "Expected _scale to be initialized, got None" + ) + stash.append(_GeneralMultiDeviceReplicator(self._scale)) + scaled_val = val * stash[0].get(val.device) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_val.type(val.dtype) + if isinstance(val, abc.Iterable): + iterator = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterator) + return iterator + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_( + self, + optimizer: torch.optim.Optimizer, + inv_scale: torch.Tensor, + found_inf: torch.Tensor, + allow_fp16: bool = True, + ) -> dict[torch.device, torch.Tensor]: + per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale) + per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be thousands of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + # coalesce is not supported in torch.float16 + param_grad_fp32 = param.grad.type(torch.float32).coalesce() + param.grad = param_grad_fp32.type(torch.float16) + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some + # ranks may have no (non-zero sized) parameter shards, necessitating the + # initialization of `per_device_found_inf._per_device_tensors` here + if not per_device_found_inf._per_device_tensors: + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + per_device_found_inf.get(self._scale.device) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: torch.optim.Optimizer) -> None: + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + if self._scale is None: + raise AssertionError("Expected _scale to be initialized, got None") + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, True + ) + optimizer_state["stage"] = OptState.UNSCALED + + # Synchronize the detected inf across the ranks + optimizer_state = self._per_optimizer_states[id(optimizer)] + works = [] + found_inf_on_cpus = [] + found_inf_on_devices = [] + + for found_inf in optimizer_state["found_inf_per_device"].values(): + if self._device != "cpu" and found_inf.device.type == "cpu": + found_inf_on_cpus.append(found_inf) + found_inf_on_device = found_inf.to(self._device) + found_inf_on_devices.append(found_inf_on_device) + works.append( + dist.all_reduce( + found_inf_on_device, async_op=True, group=self.process_group + ) + ) + else: + works.append( + dist.all_reduce(found_inf, async_op=True, group=self.process_group) + ) + for work in works: + work.wait() + if found_inf_on_cpus: + torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices) + + def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None: + """ + If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero. + Otherwise, scale is multiplied by the growth factor when the growth interval is reached. + """ + if self._scale is None or self._growth_tracker is None: + raise AssertionError( + "Expected _scale and _growth_tracker to be initialized, got None" + ) + + if found_inf.item() >= 1.0: + self._scale *= self._backoff_factor + self._growth_tracker.fill_(0) + else: + successful = self._growth_tracker + 1 + if successful == self._growth_interval: + self._scale *= self._growth_factor + self._growth_tracker.fill_(0) + else: + self._growth_tracker = successful + + def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None: + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated] + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " + "torch.FloatTensor with requires_grad=False." + ) + if new_scale.device.type != self._device: + raise AssertionError(reason) + if new_scale.numel() != 1: + raise AssertionError(reason) + if new_scale.requires_grad is not False: + raise AssertionError(reason) + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + if len(found_infs) == 0: + raise AssertionError("No inf checks were recorded prior to update.") + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if _scale.device.type == "cpu": + self._amp_update_scale_cpu_(found_inf_combined) + else: + torch._amp_update_scale_( + self._scale, # type: ignore[arg-type] + self._growth_tracker, # type: ignore[arg-type] + found_inf_combined, + self._growth_factor, # type: ignore[arg-type] + self._backoff_factor, # type: ignore[arg-type] + self._growth_interval, # type: ignore[arg-type] + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..f731854dab2eb475e7c8321738552fed205db70d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py @@ -0,0 +1,608 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Iterable, Sequence +from typing import Any, cast, Optional, Union + +import torch.nn as nn + + +__all__ = [ + "always_wrap_policy", + "lambda_auto_wrap_policy", + "transformer_auto_wrap_policy", + "size_based_auto_wrap_policy", + "enable_wrap", + "wrap", + "CustomPolicy", + "ModuleWrapPolicy", +] + + +# NOTE: We intentionally keep this function simple and isolate the complexity +# to `fn` to enable using this function generically. We may move this to a +# non-FSDP-specific folder and/or make it public in the future. +def _post_order_apply( + root_module: nn.Module, + fn: Callable[[nn.Module], Optional[nn.Module]], +): + """ + This applies ``fn`` to every module in the module tree of ``root_module`` + following a post-order traversal. If ``fn`` returns an :class:`nn.Module`, + then this replaces the original module with the newly returned one in the + tree. Otherwise, ``fn`` should return ``None``, in which case the module is + not changed. + """ + # Track visited modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = {root_module} + + def _post_order_apply_inner( + module: nn.Module, + module_name: str, + parent_module: Optional[nn.Module], + ): + for child_module_name, child_module in module.named_children(): + if child_module not in visited_modules: + visited_modules.add(child_module) + _post_order_apply_inner(child_module, child_module_name, module) + optional_module = fn(module) + if optional_module is not None: + if not isinstance(parent_module, nn.Module): + raise AssertionError( + "Non-root modules should have their parent module set but got " + f"{parent_module} for {module}" + ) + if not module_name: + raise AssertionError( + "Non-root modules should have their module name set but got " + f"an empty module name for {module}" + ) + if not isinstance(optional_module, nn.Module): + raise AssertionError( + f"fn should return None or an nn.Module but got {optional_module}" + ) + setattr(parent_module, module_name, optional_module) + + _post_order_apply_inner(root_module, "", None) + + +def _construct_wrap_fn( + root_module: nn.Module, + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], + fsdp_fn: Callable, +) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + This constructs the "wrap" function to pass to :func:`_post_order_apply` + based on ``target_module_to_kwargs``, which should be constructed from the + wrapping policy. + """ + + def fn(module: nn.Module) -> Optional[nn.Module]: + # Explicitly avoid wrapping the root module since for FSDP, it is + # handled by the caller + if module in target_module_to_kwargs and module is not root_module: + kwargs = target_module_to_kwargs[module] + return fsdp_fn(module, **kwargs) + return None + + return fn + + +def _run_mixed_precision_override_policy( + root_module: nn.Module, + module_classes: Iterable[type[nn.Module]], + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], +): + module_classes_tuple = tuple(set(module_classes)) + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes_tuple): + # This policy overrides any existing policy + if module not in target_module_to_kwargs: + # Only inherit from the root kwargs if not already specified + target_module_to_kwargs[module] = root_kwargs + target_module_to_kwargs[module]["mixed_precision"] = None + return target_module_to_kwargs + + +def always_wrap_policy(*args, **kwargs) -> bool: + """ + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. + """ + return True + + +class _Policy(ABC): + """ + This defines an abstract base class that represents a policy for applying + a module-level API. + """ + + @abstractmethod + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + """ + This should return a dict ``target_module_to_kwargs`` that maps from + each target module to wrap to its kwargs. + """ + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: set[type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_Policy): + """ + This policy applies to every module of the specified module classes, + passing in the kwargs given to the root. + """ + + def __init__(self, module_classes: Iterable[type[nn.Module]]): + module_classes_set = set(module_classes) + self._module_classes = module_classes_set + self._module_classes_str = str(module_classes_set) + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + module_classes = tuple(self._module_classes) + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes): + # Shallow copy to avoid coupling changes across modules + target_module_to_kwargs[module] = copy.copy(root_kwargs) + return target_module_to_kwargs + + def __call__(self, module, recurse, *args, **kwargs): + # nonwrapped_numel is not used. + return _module_wrap_policy( + module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes + ) + + def __repr__(self) -> str: + return super().__repr__() + f"({self._module_classes_str})" + + +class CustomPolicy(_Policy): + """ + This policy takes in a lambda function that maps a given ``nn.Module`` to + either ``False``, ``True``, or a kwarg dictionary. + - If the function returns ``False`` or an empty dictionary, then the module + does not have the API applied. + - If the function returns ``True``, then the module has the API applied + with the root's kwargs. + - If the function returns a non-empty dictionary, then the module has the + API applied, and the dictionary overrides the root's kwargs. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = init_transformer_model(...) + >>> def lambda_fn(module: nn.Module): + >>> if module is model.lm_head: + >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} + >>> elif isinstance(module, TransformerBlock): + >>> return True + >>> return False + >>> policy = CustomPolicy(lambda_fn) + >>> fsdp_model = FSDP(model, auto_wrap_policy=policy) + """ + + def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]): + self._lambda_fn = lambda_fn + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + res = self._lambda_fn(module) + if not isinstance(res, (dict, bool)): + raise ValueError( + "The lambda_fn passed to CustomPolicy should return " + f"False/True or a kwarg dict, but it returned {res}" + ) + if not res: + continue + kwargs = copy.copy(root_kwargs) + if isinstance(res, dict): + # Override the root kwargs with the ones specified by the + # lambda function + kwargs.update(res) + target_module_to_kwargs[module] = kwargs + return target_module_to_kwargs + + +def lambda_auto_wrap_policy( + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable +) -> bool: + """ + A convenient auto wrap policy to wrap submodules based on an arbitrary user + function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as + a `wrapper_cls` unit. + + Return if a module should be wrapped during auto wrapping. + + The first three parameters are required by :func:`_recursive_wrap`. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. + """ + if recurse: + return True # always recurse + return lambda_fn(module) + + +def transformer_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + transformer_layer_cls: set[type[nn.Module]], +) -> bool: + """ + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. + """ + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) + + +def _wrap_module_cls_individually( + module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs +): + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap based on whether the type of module + # is in `module_classes`. + return isinstance(module, tuple(module_classes)) + + +def _or_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + policies, +) -> bool: + """ + A policy that wraps ``module`` if any policy in the passed in iterable of + ``policies`` returns ``True``. + """ + return any( + policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) + for policy in policies + ) + + +def size_based_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + # Additional custom arguments + min_num_params: int = int(1e8), + force_leaf_modules: Optional[set[type[nn.Module]]] = None, + exclude_wrap_modules: Optional[set[type[nn.Module]]] = None, +) -> bool: + """ + A size-based auto wrap policy. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. + """ + force_leaf_modules = ( + size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] + if force_leaf_modules is None + else force_leaf_modules + ) + exclude_wrap_modules = ( + size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined] + if exclude_wrap_modules is None + else exclude_wrap_modules + ) + + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel + if recurse: + # We should recurse if the module is big enough but not in force_leaf_modules list. + return is_large and not isinstance(module, tuple(force_leaf_modules)) + else: + # If we are not recursing, determine if we should wrap. + return is_large and not isinstance(module, tuple(exclude_wrap_modules)) + + +# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported. +size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined] +size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined] + + +@contextlib.contextmanager +def enable_wrap( + *, wrapper_cls: Any, **wrapper_kwargs: Any +) -> Generator[None, None, None]: + """ + Context manager to wrap modules using a wrapper. + + Useful for when you'd like to apply the same configuration arguments to all + child modules that you wrap. A particularly important use case is wrapping + large layers so that they get sharded (in-place) during initialization, to + avoid running out of system memory. Large layers can indicate that they + should be sharded via the ``wrap`` annotation and this context manager can + provide the exact configuration for these nested instances. + + Usage:: + + with enable_wrap(wrapper_cls, **params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + wrapper_cls: + Class that `wrap` annotation will `wrap` modules with, such as + `FullyShardedDataParallel`. + **wrapper_kwargs: + Configuration settings that will be passed to all ``wrap`` + instances inside the context + """ + kwargs = { + "wrapper_cls": wrapper_cls, + **wrapper_kwargs, + } + with _ConfigAutoWrap(**kwargs): + yield + + +def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: + """ + Annotate that a module should be wrapped. Annotated modules will only be + wrapped if inside of an :func:`enable_wrap` context manager. This allows + a module to be initialized both with and without a wrapper without code + change. + + The class that this function wraps the passed in ``nn.Module`` with is the + passed in ``wrapper_cls`` argument into ``enable_wrap``. Both + ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct + the ``wrapper_cls`` instance. In the case of duplicate kwargs in + ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be + respected. + + Usage:: + + with enable_wrap(wrapper_cls=FSDP, **fsdp_config): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + module (nn.Module): module to wrap (if in :func:`enable_wrap` context) + **wrap_overrides: configuration overrides that will take priority over + the values provided by the :func:`enable_wrap` context + """ + if _ConfigAutoWrap.in_autowrap_context: + if _ConfigAutoWrap.wrapper_cls is None: + raise AssertionError("Expected _ConfigAutoWrap.wrapper_cls to be set") + + wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} + return _wrap( + module, + _ConfigAutoWrap.wrapper_cls, + **wrap_overrides, + ) + return module + + +def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: + if wrapper_cls is None: + raise AssertionError("Expected wrapper_cls to be set") + if hasattr(module, "_wrap_overrides"): + # If module has a _wrap_overrides attribute, we force overriding the + # FSDP config with these attributes for this module. Currently this + # is only used to disable mixed precision for BatchNorm when + # auto_wrapping. + overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type, dict-item] + return wrapper_cls(module, **overrides) + + return wrapper_cls(module, **kwargs) + + +def _recursive_wrap( + module: nn.Module, + auto_wrap_policy: Callable, + wrapper_cls: Callable, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + only_wrap_children: bool = False, + **kwargs: Any, +) -> tuple[nn.Module, int]: + """ + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + + Args: + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. + ignored_modules (set[torch.nn.Module]): Modules to ignore when + wrapping. + ignored_params (set[torch.nn.Parameter]): Parameters to ignore when + wrapping; these should be the parameters contained in the modules + in ``ignored_modules``. + Returns: + (nn.Module, int): + ``module`` after wrapping and the numel recursively wrapped. + """ + if auto_wrap_policy is None: + raise AssertionError("Must specify auto_wrap_policy.") + if wrapper_cls is None: + raise AssertionError("Must specify wrapper_cls") + # Make sure no child is already wrapped. + for _, child in module.named_modules(): + if child in ignored_modules: + continue + try: + if isinstance(child, cast(type, wrapper_cls)): + raise AssertionError( + f"Child module {child} is already wrapped by {wrapper_cls}" + ) + except TypeError: + # wrapper_cls is a function as opposed to a class type, just bypass above check. + pass + + # We count all params, assuming none of them are already wrapped. + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) + + if auto_wrap_policy is None: + raise AssertionError("Expected auto_wrap_policy to be set") + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 + # Iterate through the children, recursively wrap if necessary + for name, child in module.named_children(): + if child in ignored_modules: + continue + wrapped_child, num_wrapped_params = _recursive_wrap( + module=child, + auto_wrap_policy=auto_wrap_policy, + wrapper_cls=wrapper_cls, + ignored_modules=ignored_modules, + ignored_params=ignored_params, + **kwargs, + ) + setattr(module, name, wrapped_child) + # Keep track of how many parameters have been wrapped + total_wrapped_numel += num_wrapped_params + # decide if we need to wrap the current module, + # since the left over parameters exceed the number of params to wrap + remainder = nonwrapped_numel - total_wrapped_numel + if not only_wrap_children and auto_wrap_policy( + module=module, recurse=False, nonwrapped_numel=remainder + ): + # Leaf node or final wrapping of the remainder both happen here. + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel + else: + return module, total_wrapped_numel + return module, 0 + + +class _ConfigAutoWrap: + """ + Helper class to wrap modules based on default config args via a context manager. + See :func:`enable_wrap` for more information. + """ + + in_autowrap_context: bool = False # Context flag + wrapper_cls: Optional[Callable] = None # The wrapper class + kwargs: dict[str, Any] = {} # Wrapper's args + + def __init__(self, **kwargs: dict[str, Any]): + self.kwargs = kwargs + + @staticmethod + def enable_autowrap_context(kwargs: Any) -> None: + if _ConfigAutoWrap.in_autowrap_context: + raise NotImplementedError( + "You are already within an autowrap context and we currently do not supported nested autowrap." + ) + _ConfigAutoWrap.in_autowrap_context = True + # Get and save the wrapper cls for the context. + if "wrapper_cls" not in kwargs: + raise AssertionError( + "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." + ) + _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) + del kwargs["wrapper_cls"] + # Save the rest. + _ConfigAutoWrap.kwargs = kwargs + + @staticmethod + def disable_autowrap_context() -> None: + _ConfigAutoWrap.in_autowrap_context = False + _ConfigAutoWrap.wrapper_cls = None + _ConfigAutoWrap.kwargs = {} + + def __enter__(self) -> None: + self.enable_autowrap_context(self.kwargs) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.disable_autowrap_context() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb744a2b93615b703eb0dafb7c8e6c71bc1ad5d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.launcher.api import ( # noqa: F401 + elastic_launch, + launch_agent, + LaunchConfig, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..349ee10501a13ddfa5895813e4c3075084400c9b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e346c4c86f78da2cfceea607135b1b1a3d8927 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2adf5549fecf13560d0c8637085872688c9454a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launcher/api.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import sys +import uuid +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.distributed.elastic.rendezvous.registry as rdzv_registry +from torch._utils_internal import get_default_numa_options +from torch.distributed.elastic import events, metrics +from torch.distributed.elastic.agent.server.api import WorkerSpec +from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError +from torch.distributed.elastic.rendezvous import RendezvousParameters +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint +from torch.distributed.elastic.utils.logging import get_logger +from torch.numa.binding import NumaOptions + + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] + +logger = get_logger(__name__) + + +@dataclass +class LaunchConfig: + """ + Creates a rendezvous config. + + Args: + min_nodes: Minimum amount of nodes that the user function will + be launched on. Elastic agent ensures that the user + function start only when the min_nodes amount enters + the rendezvous. + max_nodes: Maximum amount of nodes that the user function + will be launched on. + nproc_per_node: On each node the elastic agent will launch + this amount of workers that will execute user + defined function. + rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). + rdzv_endpoint: The endpoint of the rdzv sync. storage. + rdzv_configs: Key, value pair that specifies rendezvous specific configuration. + rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going + to be removed in future versions, see the note below. The default timeout is 900 seconds. + run_id: The unique run id of the job (if not passed a unique one will be + deduced from run environment - flow workflow id in flow - or auto generated). + role: User defined role of the worker (defaults to "trainer"). + max_restarts: The maximum amount of restarts that elastic agent will conduct + on workers before failure. + monitor_interval: The interval in seconds that is used by the elastic_agent + as a period of monitoring workers. + start_method: The method is used by the elastic agent to start the + workers (spawn, fork, forkserver). + metrics_cfg: configuration to initialize metrics. + local_addr: address of the local node if any. If not set, a lookup on the local + machine's FQDN will be performed. + local_ranks_filter: ranks for which to show logs in console. If not set, show from all. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines + that match _any_ of the filter strings. + duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines + that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). + When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. + + + .. note:: + `rdzv_timeout` is a legacy argument that will be removed in future. + Set the timeout via `rdzv_configs['timeout']` + + """ + + min_nodes: int + max_nodes: int + nproc_per_node: int + logs_specs: LogsSpecs | None = None + run_id: str = "" + role: str = "default_role" + rdzv_endpoint: str = "" + rdzv_backend: str = "etcd" + rdzv_configs: dict[str, Any] = field(default_factory=dict) + rdzv_timeout: int = -1 + max_restarts: int = 3 + monitor_interval: float = 0.1 + start_method: str = "spawn" + log_line_prefix_template: str | None = None + metrics_cfg: dict[str, str] = field(default_factory=dict) + local_addr: str | None = None + event_log_handler: str = "null" + numa_options: NumaOptions | None = None + signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT" + duplicate_stdout_filters: list[str] | None = None + duplicate_stderr_filters: list[str] | None = None + virtual_local_rank: bool = False + + def __post_init__(self): + default_timeout = 900 + if self.rdzv_timeout != -1: + self.rdzv_configs["timeout"] = self.rdzv_timeout + elif "timeout" not in self.rdzv_configs: + self.rdzv_configs["timeout"] = default_timeout + + # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage + if self.logs_specs is None: + self.logs_specs = DefaultLogsSpecs() + + if ( + self.numa_options is None + and torch.cuda.is_available() + # We assume local_rank n uses cuda device n. + and torch.cuda.device_count() == self.nproc_per_node + ): + self.numa_options = get_default_numa_options() + logger.info("Using default numa options = %r", self.numa_options) + + +class elastic_launch: + """ + Launches an torchelastic agent on the container that invoked the entrypoint. + + 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ + ``entrypoint`` can be a function or a command. + 2. The return value is a map of each worker's output mapped + by their respective global rank. + + Usage + + :: + + def worker_fn(foo): + # ... + + def main(): + # entrypoint is a function. + outputs = elastic_launch(LaunchConfig, worker_fn)(foo) + # return rank 0's output + return outputs[0] + + # entrypoint is a command and ``script.py`` is the python module. + outputs = elastic_launch(LaunchConfig, "script.py")(args) + outputs = elastic_launch(LaunchConfig, "python")("script.py") + """ + + def __init__( + self, + config: LaunchConfig, + entrypoint: Callable | str | None, + ): + self._config = config + self._entrypoint = entrypoint + + def __call__(self, *args): + return launch_agent(self._config, self._entrypoint, list(args)) + + +def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str: + """Retrieve entrypoint name with the rule: + 1. If entrypoint is a function, use ``entrypoint.__qualname__``. + 2. If entrypoint is a string, check its value: + 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` + which does not start with hifen letter (for example, "-u" will be skipped). + 2.2 otherwise, use ``entrypoint`` value. + 3. Otherwise, return empty string. + """ + if isinstance(entrypoint, Callable): # type: ignore[arg-type] + return entrypoint.__name__ # type: ignore[union-attr] + elif isinstance(entrypoint, str): + if entrypoint == sys.executable: + return next((arg for arg in args if arg[0] != "-"), "") + else: + return entrypoint + else: + return "" + + +def _get_addr_and_port( + rdzv_parameters: RendezvousParameters, +) -> tuple[str | None, int | None]: + if rdzv_parameters.backend != "static": + return (None, None) + endpoint = rdzv_parameters.endpoint + endpoint = endpoint.strip() + if not endpoint: + raise ValueError( + "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) + if master_port == -1: + raise ValueError( + f"port is missing in endpoint: {endpoint}. Try to specify --master-port" + ) + return (master_addr, master_port) + + +def launch_agent( + config: LaunchConfig, + entrypoint: Callable | str | None, + args: list[Any], +) -> dict[int, Any]: + if not config.run_id: + run_id = str(uuid.uuid4().int) + logger.warning("config has no run_id, generated a random run_id: %s", run_id) + config.run_id = run_id + + entrypoint_name = _get_entrypoint_name(entrypoint, args) + + logger.info( + "Starting elastic_operator with launch configs:\n" + " entrypoint : %(entrypoint)s\n" + " min_nodes : %(min_nodes)s\n" + " max_nodes : %(max_nodes)s\n" + " nproc_per_node : %(nproc_per_node)s\n" + " run_id : %(run_id)s\n" + " rdzv_backend : %(rdzv_backend)s\n" + " rdzv_endpoint : %(rdzv_endpoint)s\n" + " rdzv_configs : %(rdzv_configs)s\n" + " max_restarts : %(max_restarts)s\n" + " monitor_interval : %(monitor_interval)s\n" + " log_dir : %(log_dir)s\n" + " metrics_cfg : %(metrics_cfg)s\n" + " event_log_handler : %(event_log_handler)s\n" + " numa_options : %(numa_options)s\n" + " signals_to_handle : %(signals_to_handle)s\n" + " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n" + " duplicate_stderr_filters : %(duplicate_stderr_filters)s\n", + { + "entrypoint": entrypoint_name, + "min_nodes": config.min_nodes, + "max_nodes": config.max_nodes, + "nproc_per_node": config.nproc_per_node, + "run_id": config.run_id, + "rdzv_backend": config.rdzv_backend, + "rdzv_endpoint": config.rdzv_endpoint, + "rdzv_configs": config.rdzv_configs, + "max_restarts": config.max_restarts, + "monitor_interval": config.monitor_interval, + "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] + "metrics_cfg": config.metrics_cfg, + "event_log_handler": config.event_log_handler, + "numa_options": config.numa_options, + "signals_to_handle": config.signals_to_handle, + "duplicate_stdout_filters": config.duplicate_stdout_filters, + "duplicate_stderr_filters": config.duplicate_stderr_filters, + }, + ) + + rdzv_parameters = RendezvousParameters( + backend=config.rdzv_backend, + endpoint=config.rdzv_endpoint, + run_id=config.run_id, + min_nodes=config.min_nodes, + max_nodes=config.max_nodes, + local_addr=config.local_addr, + **config.rdzv_configs, + ) + + master_addr, master_port = _get_addr_and_port(rdzv_parameters) + + # Set the signals to handle in the environment variable + os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle + + spec = WorkerSpec( + role=config.role, + local_world_size=config.nproc_per_node, + entrypoint=entrypoint, + args=tuple(args), + rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), + max_restarts=config.max_restarts, + monitor_interval=config.monitor_interval, + master_addr=master_addr, + master_port=master_port, + local_addr=config.local_addr, + event_log_handler=config.event_log_handler, + numa_options=config.numa_options, + duplicate_stdout_filters=config.duplicate_stdout_filters, + duplicate_stderr_filters=config.duplicate_stderr_filters, + virtual_local_rank=config.virtual_local_rank, + ) + + agent = LocalElasticAgent( + spec=spec, + logs_specs=config.logs_specs, # type: ignore[arg-type] + start_method=config.start_method, + log_line_prefix_template=config.log_line_prefix_template, + ) + + shutdown_rdzv = True + try: + metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) + + result = agent.run() + # records that agent.run() has succeeded NOT that workers have succeeded + events.record(agent.get_event_succeeded(), config.event_log_handler) + + if result.is_failed(): + # ChildFailedError is treated specially by @record + # if the error files for the failed children exist + # @record will copy the first error (root cause) + # to the error file of the launcher process. + raise ChildFailedError( + name=entrypoint_name, + failures=result.failures, + ) + + return result.return_values + except ChildFailedError: + raise + except SignalException: + # when the agent dies with a signal do NOT shutdown the rdzv_handler + # since this closes the rendezvous on this rdzv_id permanently and + # prevents any additional scaling events + shutdown_rdzv = False + events.record(agent.get_event_failed(), config.event_log_handler) + raise + except Exception: + events.record(agent.get_event_failed(), config.event_log_handler) + raise + finally: + if shutdown_rdzv: + spec.rdzv_handler.shutdown() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e15fb517052e4aefeb7377d1f0ca63cf2b2da753 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/__init__.py @@ -0,0 +1,7 @@ +import torch + +from .functional import * # noqa: F403 + + +if torch.distributed.rpc.is_available(): + from .api.remote_module import RemoteModule diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/functional.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..287775be924a399aff01fcda66f6ebc838c62873 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/nn/functional.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch.autograd import Function + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from torch.distributed import group, ReduceOp + + +def broadcast(tensor, src, group=group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + +def all_gather(tensor, group=group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AllGather.apply(group, tensor) + + +def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> # xdoctest: +SKIP("incorrect want text") + >>> output_tensor = torch.zeros(2, dtype=torch.int64) + >>> output_tensor + [tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank + >>> tensor + tensor([1]) # Rank 0 + tensor([2]) # Rank 1 + >>> dist.all_gather_base(output_tensor, tensor) + >>> output_tensor + tensor([1,2]) # Rank 0 + tensor([1,2]) # Rank 1 + + .. warning:: + `_all_gather_base` is experimental and subject to change. + It is the caller's responsibility to ensure the output_tensor + is correctly sized. + + """ + return _AllGatherBase.apply(output_tensor, input_tensor, group) + + +def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Arguments: + output_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list to all processes in a group. + + Then concatenate the received tensors from all the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) + + +def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank(group=group) + # torch.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + + tensor = tensor.contiguous() + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = torch.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _Reduce_Scatter(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply(ctx.group, grad_output) + + +class _AllGather(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, tensor): + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) + ] + + dist.all_gather(out_tensor_list, tensor, group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + rank = dist.get_rank(group=ctx.group) + gx = torch.empty_like(grad_outputs[rank]) + gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) + else: + # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum() + # to emulate the ReduceScatter behavior + tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + gx = torch.sum(torch.stack(gxs), dim=0) + return (None, gx) + + +class _AllGatherBase(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, output_tensor, input_tensor, group): + ctx.group = group + dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) + return output_tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + world_size = dist.get_world_size(group=ctx.group) + out_size = list(grad_output.size()) + if out_size[0] % world_size != 0: + raise RuntimeError( + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" + ) + out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) + dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) + else: + raise RuntimeError("Backend not supported!") + return (None, gx, None) + + +class _AlltoAll(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, out_tensor_list, *tensors): + ctx.group = group + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) + ] + my_rank = dist.get_rank(group=group) + tensors = tuple(t.contiguous() for t in tensors) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + tensor_list = [ + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) + for size in ctx.input_tensor_size_list + ] + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) + + +class _AllReduce(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone(memory_format=torch.contiguous_format) + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faac68bb632934ba730ba7c5ce3cf7fe934a58cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/__init__.py @@ -0,0 +1,44 @@ +""" +:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list +of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the +optimizer locally on the workers where the parameters live. The distributed +optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to +apply the gradients on each worker. +""" + +import warnings + +import torch +from torch import optim + +from .apply_optimizer_in_backward import ( + _apply_optimizer_in_backward, + _get_in_backward_optimizers, +) +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD +from .named_optimizer import _NamedOptimizer +from .utils import as_functional_optim + + +# DistributedOptimizer imports torch.distributed.rpc names, so gate availability +# based on RPC being available. +if hasattr(torch._C, "_rpc_init"): + from .optimizer import DistributedOptimizer + +from .post_localSGD_optimizer import PostLocalSGDOptimizer +from .zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..c3434a4cd4f081843295e488c18a67a5c297fcbf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/_deprecation_warning.py @@ -0,0 +1,16 @@ +import warnings + +import torch + + +@torch.jit.ignore # type: ignore[misc] +def _scripted_functional_optimizer_deprecation_warning(stacklevel: int = 0) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`TorchScript` support for functional optimizers is deprecated " + "and will be removed in a future PyTorch release. " + "Consider using the `torch.compile` optimizer instead.", + DeprecationWarning, + stacklevel=stacklevel + 2, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff9854793df1aa96a27cb105a1afd1190df942a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py @@ -0,0 +1,121 @@ +from collections.abc import Iterable +from typing import Any, no_type_check + +import torch + + +__all__: list[str] = [] + +# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter +# without changing it's life-time. +# NOTE: Alternative is to add the meta-data as an attribute to the tensor, +# but that will serialize the meta-data if Tensor is serialized. +param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() +param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + + +@no_type_check +def _apply_optimizer_in_backward( + optimizer_class: type[torch.optim.Optimizer], + params: Iterable[torch.nn.Parameter], + optimizer_kwargs: dict[str, Any], + register_hook: bool = True, +) -> None: + """ + Upon ``backward()``, the optimizer specified for each parameter will fire after + the gradient has been accumulated into the parameter. + + Note - gradients for these parameters will be set to None after ``backward()``. + This means that any other optimizer not specified via `_apply_optimizer_in_backward` + over this parameter will be a no-op. + + Args: + optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter + params: (Iterator[nn.Parameter]): parameters to apply optimizer state to + optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor + register_hook: (bool): whether to register a hook that runs the optimizer + after gradient for this parameter is accumulated. This is the default + way that optimizer in backward is implemented, but specific use cases + (such as DDP) may wish to override this to implement custom behavior. + (Default = True) + + Example:: + params_generator = model.parameters() + param_1 = next(params_generator) + remainder_params = list(params_generator) + + apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02}) + apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04}) + + model(...).sum().backward() # after backward, parameters will already + # have their registered optimizer(s) applied. + + """ + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") + + @no_type_check + def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: + # view_as creates a node in autograd graph that allows us access to the + # parameter's AccumulateGrad autograd function object. We register a + # hook on this object to fire the optimizer when the gradient for + # this parameter is ready (has been accumulated into .grad field) + + # Don't create a new acc_grad if we already have one + # i.e. for shared parameters or attaching multiple optimizers to a param. + if param not in param_to_acc_grad_map: + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] + + optimizer = optimizer_class([param], **optimizer_kwargs) + + if not hasattr(param, "_in_backward_optimizers"): + param._in_backward_optimizers = [] # type: ignore[attr-defined] + # TODO: Remove these attributes once we have a better way of accessing + # optimizer classes and kwargs for a parameter. + param._optimizer_classes = [] # type: ignore[attr-defined] + param._optimizer_kwargs = [] # type: ignore[attr-defined] + + param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined] + param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined] + param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined] + + if not register_hook: + return + + def optimizer_hook(*_unused) -> None: + for opt in param._in_backward_optimizers: # type: ignore[attr-defined] + opt.step() + + param.grad = None + + handle = param_to_acc_grad_map[param].register_hook(optimizer_hook) # type: ignore[attr-defined] + if param not in param_to_optim_hook_handle_map: + param_to_optim_hook_handle_map[param] = [] + param_to_optim_hook_handle_map[param].append(handle) + + for param in params: + _apply_optimizer_in_backward_to_param(param) + + +def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]: + """ + Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these + optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called + by the user and are intended to be used for things like checkpointing. + + Args: + module: (torch.nn.Module): model to retrieve in-backward optimizers for + + Returns: + List[torch.optim.Optimizer]: the in-backward optimizers. + + Example:: + _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01}) + optims = _get_optimizers_in_backward(model) + """ + optims: list[torch.optim.Optimizer] = [] + for param in module.parameters(): + optims.extend(getattr(param, "_in_backward_optimizers", [])) + + return optims diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..e8455c5ef5a41613dc15140b6c562ceb3134ca4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adadelta.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adadelta Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdadelta: + def __init__( + self, + params: list[Tensor], + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "rho": rho, + "eps": eps, + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + acc_deltas = [] + state_steps = [] + lr = self.defaults["lr"] + rho = self.defaults["rho"] + eps = self.defaults["eps"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..3da4e29b3f0154ab58206c835f80a24ae208a05c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adagrad.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adagrad Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly let the user pass gradients to the `step` function +# this is so that we could separate the gradients and parameters +# and allow multithreaded trainer to update the parameters +# without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdagrad: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + warmup_lr_multiplier: float = 1.0, + warmup_num_iters: float = 0.0, + eps: float = 1e-10, + coalesce_grad: bool = True, + foreach: bool = False, + fused: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "lr_decay": lr_decay, + "eps": eps, + "weight_decay": weight_decay, + "initial_accumulator_value": initial_accumulator_value, + "warmup_lr_multiplier": warmup_lr_multiplier, + "warmup_num_iters": warmup_num_iters, + } + self.coalesce_grad = coalesce_grad + self.foreach = foreach + self.fused = fused + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + # TODO: no union or any types in TorchScript, make step a scalar tensor instead + # This is also needed by if we want to share_memory on the step across processes + for p in self.param_group["params"]: + self.state[p] = { + "sum": torch.full_like(p.data, initial_accumulator_value), + "step": torch.tensor(0.0), + } + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + state_sums = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad, has_complex = False, False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_sparse_grad |= gradient.is_sparse + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + state = self.state[param] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adagrad( + params, + grads, + state_sums, + state_steps, + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + lr_decay=self.defaults["lr_decay"], + eps=self.defaults["eps"], + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..1763edd14c9da1c19081fcc1334e267c889472c1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adam.py @@ -0,0 +1,201 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adam Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdam: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + """ + Similar to step, but operates on a single parameter and optionally a + gradient tensor. + """ + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = False + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..595a5668a78fc0f8451fa9e2a81c03d049bb4b82 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamax.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adamax Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamax: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_infs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_inf"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=self.defaults["eps"], + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..d695ce8b473af8fbf1bde28293e576ff69fe6f04 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_adamw.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional AdamW Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamW: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..45341b03237b456419ec181ae8b771dec081d3cb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rmsprop.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional RMSprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRMSprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + } + self.centered = centered + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + grad_avgs = [] + momentum_buffer_list = [] + state_steps = [] + lr = self.defaults["lr"] + alpha = self.defaults["alpha"] + eps = self.defaults["eps"] + momentum = self.defaults["momentum"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if momentum > 0: + state["momentum_buffer"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.centered: + state["grad_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + if momentum > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if self.centered: + grad_avgs.append(state["grad_avg"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=self.centered, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc9c510dabca7871d19890c5e52e0f5eeafcd49 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_rprop.py @@ -0,0 +1,106 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Rprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + etas: tuple[float, float] = (0.5, 1.2), + step_sizes: tuple[float, float] = (1e-6, 50), + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + } + self.etas = etas + self.step_sizes = step_sizes + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + prevs = [] + step_sizes = [] + state_steps = [] + lr = self.defaults["lr"] + etaminus, etaplus = self.etas + step_size_min, step_size_max = self.step_sizes + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["prev"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["step_size"] = torch.full_like(gradient, lr) + + state = self.state[param] + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rprop( + params_with_grad, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..aed92403e6fb62394e2f755fffbe5b7f323200ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/functional_sgd.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional SGD Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalSGD: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + } + self.nesterov = nesterov + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Tensor | None): + """Similar to self.step, but operates on a single parameter and + its gradient. + """ + # TODO: Once step_param interface is robust, refactor step to call + # step param on each param. + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + lr = self.defaults["lr"] + params = [param] + momentum_buffer_list: list[Tensor | None] = [] + grads = [] + + has_sparse_grad = False + if grad is not None: + grads.append(grad) + if grad.is_sparse: + has_sparse_grad = True + if param not in self.state: + self.state[param] = {} + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + # update momentum_buffer in state + state = self.state[param] + momentum_buffer = momentum_buffer_list[0] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer + + def step(self, gradients: list[Tensor | None]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + momentum_buffer_list: list[Tensor | None] = [] + lr = self.defaults["lr"] + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad = False + for param, gradient in zip(params, gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + if gradient.is_sparse: + has_sparse_grad = True + + if param not in self.state: + self.state[param] = {} + + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params_with_grad, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + # update momentum_buffers in state + for i, p in enumerate(params_with_grad): + state = self.state[p] + momentum_buffer = momentum_buffer_list[i] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a8432e198a083e194a1e48bf8c0af76ffa6b83a1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/named_optimizer.py @@ -0,0 +1,328 @@ +import logging +import warnings +from collections.abc import Callable, Collection, Mapping +from copy import deepcopy +from typing import Any, overload + +import torch +import torch.nn as nn +from torch import optim +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + +__all__: list[str] = [] + +logger = logging.getLogger(__name__) + + +class _NamedOptimizer(optim.Optimizer): + """ + ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. + + We replace the original key (number) in an optim to the + fully qualified name (FQN) string. User can initialize the optim as they + initialize a PyTorch optim, the only difference is that they also need to + pass in the FQN of each parameters. + + Args: + named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): + Mapping from FQN to parameter. + optimizer_class (optim.Optimizer): + The class of optimizer to instantiate. + param_groups (Collection[Mapping[str, Any]]): + `param_groups` to pass to optimizer if specified. + The key of the inner map needs to be FQNs. + Default: None + module (nn.Module): the module whose parameters to updated + by the optimizer. + args: arguments to pass to the optimizer constructor. + kwargs: arguments to pass to the optimizer constructor. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch import optim + >>> from torch.distributed.optim import _NamedOptimizer + >>> + >>> # Define the named optimizer. + >>> m = Model(...) + >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) + >>> # Forward pass + backward pass. + >>> named_optim.step() + >>> ... + >>> # Call state_dict for the named optimizer returns a FQN state_dict. + >>> named_optim.state_dict() + + Warning: This API is still in development and subject to change. + + TODO: Add tutorial for _NamedOptimizer. + TODO: Add documentation in the docstring for the public attributes + like self.param_groups and self.named_parameters. + """ + + def __init__( + self, + named_parameters: Mapping[str, torch.Tensor | ShardedTensor], + optimizer_class: optim.Optimizer, + param_groups: Collection[Mapping[str, Any]] | None = None, + module: nn.Module | None = None, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ) -> None: + torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") + self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] + self._param_groups_check() + self.named_parameters = dict(named_parameters) + params_for_optimizer = ( + self.named_parameters.values() if param_groups is None else param_groups + ) + self._optimizer = optimizer_class( # type: ignore[operator] + params_for_optimizer, + *args, + **kwargs, + ) + self.module = module + if param_groups is None: + self.ordered_param_keys = list(self.named_parameters.keys()) + else: + warnings.warn( + "Since we pass in param_groups, we will use param_groups to " + "initialize the optimizer, not all parameters of the module.", + stacklevel=2, + ) + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + ordered_param_keys = [] + for group in param_groups: + for param in group["params"]: + if param not in param_to_key: + raise ValueError( + f"Expect param name {param} found in param group but is missing." + ) + ordered_param_keys.append(param_to_key[param]) + self.ordered_param_keys = ordered_param_keys + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def _param_groups_check(self) -> None: + if self.param_groups is not None: + for param_group in self.param_groups: + assert isinstance(param_group, dict), "param group must be a dict" + assert "params" in param_group, "param group must contain key params" + params = param_group["params"] + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + for param in params: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + param_group["params"] = params + + def state_dict(self) -> dict[str, Any]: + """ + Return the ``state_dict`` of the optimizer. + + Instead of using number to index + parameters, we will use module fully qualified name (FQN) as the key. + """ + state_dict = self._optimizer.state_dict() + param_groups = state_dict["param_groups"] + + ret_state = { + self.ordered_param_keys[st_key]: state_val + for st_key, state_val in state_dict["state"].items() + } + + ret_groups = [] + for group in param_groups: + param_keys = [self.ordered_param_keys[param] for param in group["params"]] + ret_group = {"params": sorted(param_keys)} + for k, v in group.items(): + if k != "params": + ret_group[k] = deepcopy(v) + ret_groups.append(ret_group) + + return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) + + @overload + def step(self, closure: None = None) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """ + Perform a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on the wrapped + optimizer. + """ + return self._optimizer.step(closure=closure) + + @property + def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override] + return self._optimizer.state + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Define the default behavior to load a state_dict for ``_NamedOptimizer``. + + Sample Code + ``` + my_model = MyModule() + optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) + ... + + optim_state_dict = optimizer.state_dict() + ... + ... + + optimizer.load_state_dict(optim_state_dict) + ... + ``` + Args: + state_dict (dict[str, Any]) : A ``state_dict`` to load into the optimizer. + Note that this state dict update is performed in place. + + .. note:: PyTorch is using lazy init to initialize the optim states. + So it is possible that there is no optim state when user call + ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter + that users can only call ``load_state_dict`` after the state is initialized. + By doing this, we can validate the optim ``state_dict`` to be loaded. + """ + new_state_dict = self._optimizer.state_dict() + state_dict = self._pre_load_state_dict(state_dict) + state = state_dict["state"] + new_state = new_state_dict["state"] + if len(new_state) == 0: + raise ValueError( + "Expects the optim to be initialized before load but found not initialized." + ) + + for idx, param_key in enumerate(self.ordered_param_keys): + # When the conditional training is performed, not all parameters are updated in the optim. + if param_key not in state: + continue + if len(state[param_key]) != len(new_state[idx]): + raise ValueError( + f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" + ) + # Iterate through all optimizer states. + for state_key, state_val in new_state[idx].items(): + if state_key not in state[param_key]: + raise ValueError( + f"Expects state {state_key} for parameter {param_key} but not found." + ) + + src_state_val = state[param_key][state_key] + if isinstance(state_val, ShardedTensor): + assert isinstance(src_state_val, ShardedTensor) + num_shards = len(state_val.local_shards()) + num_new_shards = len(src_state_val.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" + ) + for shard, src_shard in zip( + state_val.local_shards(), src_state_val.local_shards() + ): + shard.tensor.detach().copy_(src_shard.tensor) + elif isinstance(state_val, torch.Tensor): + assert isinstance(src_state_val, torch.Tensor) + state_val.detach().copy_(src_state_val) + else: + new_state[idx][state_key] = deepcopy(src_state_val) + + # Load param_groups of state_dict + src_param_groups = state_dict["param_groups"] + new_param_groups = new_state_dict["param_groups"] + + src_group_map = {} + for group in src_param_groups: + param_keys = list(group["params"]) + src_group_map[_gen_param_group_key(param_keys)] = group + new_group_map = {} + for new_group in new_param_groups: + param_keys = [] + for param_key in new_group["params"]: + param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] + new_group_map[_gen_param_group_key(param_keys)] = new_group + for group_key, new_group in new_group_map.items(): + # When not all parameters are used in training or receive gradient, aka., not all parameters + # would be in the param_group. Thus we skip the group_key here. + if group_key not in src_group_map: + continue + src_group = src_group_map[group_key] + if len(src_group) != len(new_group): + raise ValueError( + f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." + ) + for k in src_group: + if k not in new_group: + raise ValueError( + f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." + ) + if k != "params": + new_group[k] = deepcopy(src_group[k]) + + self._optimizer.load_state_dict(new_state_dict) + + def add_param_group(self, param_group: Mapping[str, Any]) -> None: + """ + Add a param group to the :class:`_NamedOptimizer` s `param_groups`. + + Warning: This API is still in development and subject to change. + """ + assert isinstance(param_group, dict), "param group must be a dict" + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + else: + param_group["params"] = list(params) + + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + for param in param_group["params"]: + if param not in param_to_key: + raise ValueError("some parameters are not in the module") + self.ordered_param_keys.append(param_to_key[param]) + + self._optimizer.add_param_group(param_group) + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def init_state(self) -> None: + """ + Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers. + + This allows doing in-place loading of optimizer state from a checkpoint. + """ + for param in self.named_parameters.values(): + if param.requires_grad: + t = torch.zeros_like(param) + param.grad = torch.autograd.Variable(t) + # Calling ``step`` will load the initial state for optimizer states. + self.step(closure=None) + + def _pre_load_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + return FSDP.optim_state_dict_to_load( + self.module, self._optimizer, state_dict, is_named_optimizer=True + ) + return state_dict + + def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + FSDP.optim_state_dict(self.module, self._optimizer, state_dict) + return state_dict + + +def _gen_param_group_key(param_keys: list[str]) -> str: + """Concatenate all param keys as a unique identifier for one param group.""" + return "/".join(sorted(param_keys)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f9477aa414b429e4cb4ca8bf1d1fedf9788d4eaa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/optimizer.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +import logging +from collections import defaultdict +from threading import Lock + +import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.jit as jit +import torch.nn as nn +from torch import Tensor +from torch.distributed.rpc import RRef + +from .utils import functional_optim_map + + +__all__ = ["DistributedOptimizer"] + +logger = logging.getLogger(__name__) + + +# XXX: we define a _ScriptModuleOptimizer here to explicitly +# compile the FunctionalOptimizer class into TorchScript +# This is because ScriptClass instance still lives in +# python unless you explicitly compile it as an attribute +# in ScriptModule or pass it to a ScriptFunction +# _ScriptLocalOptimizerInterface serves as a common +# interface type for Optimizer ScriptModules. +# +# TODO (wanchaol): remove this once we added TorchScript +# class reference semantics +@jit.interface +class _ScriptLocalOptimizerInterface: + def step(self, autograd_ctx_id: int) -> None: + pass + + +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. + # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + super().__init__() + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + @jit.export + def step(self, autograd_ctx_id: int): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + # apply functional optimizer step with a list of gradients + grads: list[Tensor | None] = [ + all_local_grads[p] if p in all_local_grads else None # noqa: SIM401 + for p in self._local_params + ] + + self.optim.step(grads) + + +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim +class _LocalOptimizer: + # Ideally we would only need to share a lock for instances of + # _LocalOptimizer that deal with the same parameters. We are + # making a simplifying assumption here that if there is more + # than one instance of _LocalOptimizer per worker, they will + # be optimizing the same parameters (e.g. each data parallel + # trainer will create its own instance of _LocalOptimizer but + # they will all optimize the same parameters on each worker) + global_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + def step(self, autograd_ctx_id): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + + with _LocalOptimizer.global_lock: + for param, grad in all_local_grads.items(): + param.grad = grad + self.optim.step() + + +def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)) + + +def _local_optimizer_step(local_optim_rref, autograd_ctx_id): + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer +def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface) + + +@jit.script +def _script_local_optimizer_step( + local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int +) -> None: + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +def _wait_for_all(rpc_futs): + # TODO: improve error propagation + exception = None + results = [] + for fut in rpc_futs: + try: + results.append(fut.wait()) + except Exception as e: + results.append(e) + exception = e + if exception is not None: + raise exception + return results + + +class DistributedOptimizer: + """ + DistributedOptimizer takes remote references to parameters scattered + across workers and applies the given optimizer locally for each parameter. + + This class uses :meth:`~torch.distributed.autograd.get_gradients` in order + to retrieve the gradients for specific parameters. + + Concurrent calls to + :meth:`~torch.distributed.optim.DistributedOptimizer.step`, + either from the same or different clients, will + be serialized on each worker -- as each worker's optimizer can only work + on one set of gradients at a time. However, there is no guarantee that + the full forward-backward-optimizer sequence will execute for one client + at a time. This means that the gradients being applied may not correspond + to the latest forward pass executed on a given worker. Also, there is no + guaranteed ordering across workers. + + `DistributedOptimizer` creates the local optimizer with TorchScript enabled + by default, so that optimizer updates are not blocked by the Python Global + Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed + Model Parallel). This feature is currently enabled for most optimizers. You + can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support + for your own custom optimizers. + + Args: + optimizer_class (optim.Optimizer): the class of optimizer to + instantiate on each worker. + params_rref (list[RRef]): list of RRefs to local or remote parameters + to optimize. + args: arguments to pass to the optimizer constructor on each worker. + kwargs: arguments to pass to the optimizer constructor on each worker. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> import torch.distributed.autograd as dist_autograd + >>> import torch.distributed.rpc as rpc + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> + >>> with dist_autograd.context() as context_id: + >>> # Forward pass. + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> loss = rref1.to_here() + rref2.to_here() + >>> + >>> # Backward pass. + >>> dist_autograd.backward(context_id, [loss.sum()]) + >>> + >>> # Optimizer. + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> [rref1, rref2], + >>> lr=0.05, + >>> ) + >>> dist_optim.step(context_id) + + __ https://github.com/pytorch/tutorials/pull/1465 + """ + + def __init__(self, optimizer_class, params_rref, *args, **kwargs): + torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer") + per_worker_params_rref = defaultdict(list) + for param in params_rref: + per_worker_params_rref[param.owner()].append(param) + + if optimizer_class in functional_optim_map and jit._state._enabled: + optim_ctor = functional_optim_map.get(optimizer_class) + else: + optim_ctor = optimizer_class + self.is_functional_optim = optim_ctor != optimizer_class + + if self.is_functional_optim: + optimizer_new_func = _new_script_local_optimizer + else: + logger.warning( + "Creating the optimizer %s without TorchScript support, " + "this might result in slow computation time in multithreading environment" + "(i.e. Distributed Model Parallel training on CPU) due to the Python's " + "Global Interpreter Lock (GIL). Please file an issue if you need this " + "optimizer in TorchScript. ", + optimizer_class, + ) + optimizer_new_func = _new_local_optimizer + + remote_optim_futs = [] + for worker, param_rrefs in per_worker_params_rref.items(): + remote_optim_rref_fut = rpc.rpc_async( + worker, + optimizer_new_func, + args=(optim_ctor, param_rrefs) + args, + kwargs=kwargs, + ) + remote_optim_futs.append(remote_optim_rref_fut) + + self.remote_optimizers = _wait_for_all(remote_optim_futs) + + def step(self, context_id): + """ + Performs a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on each worker + containing parameters to be optimized, and will block until all workers + return. The provided ``context_id`` will be used to retrieve the + corresponding :class:`~torch.distributed.autograd.context` that + contains the gradients that should be applied to the parameters. + + Args: + context_id: the autograd context id for which we should run the + optimizer step. + """ + dist_autograd._is_valid_context(context_id) + + optimizer_step_func = ( + _script_local_optimizer_step + if self.is_functional_optim + else _local_optimizer_step + ) + + rpc_futs = [ + rpc.rpc_async( + optimizer.owner(), + optimizer_step_func, + args=(optimizer, context_id), + ) + for optimizer in self.remote_optimizers + ] + _wait_for_all(rpc_futs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b78510ed1a111998a4eda21546b003eedbcce7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/post_localSGD_optimizer.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +import torch.distributed.algorithms.model_averaging.averagers as averagers + + +class PostLocalSGDOptimizer(torch.optim.Optimizer): + r""" + Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD `_, + This optimizer runs local optimizer at every step. + After the warm-up stage, it averages parameters periodically after the local optimizer is applied. + + Args: + optim: The local optimizer. + averager: A model averager instance to run post-localSGD algorithm. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> import torch.distributed as dist + >>> import torch.distributed.algorithms.model_averaging.averagers as averagers + >>> import torch.nn as nn + >>> from torch.distributed.optim import PostLocalSGDOptimizer + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Create a post-localSGD optimizer that wraps a local optimizer. + >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as + >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) + >>> opt = PostLocalSGDOptimizer( + >>> optim=local_optim, + >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> ) + >>> + >>> # In the first 100 steps, DDP runs global gradient averaging at every step. + >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), + >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. + >>> for step in range(0, 200): + >>> opt.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> opt.step() + """ + + def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager): + self.optim = optim + self.param_groups = self.optim.param_groups + self.averager = averager + + @property + def state(self): # type: ignore[override] + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`, + but adds an extra entry to record model averager's step to the checkpoint + to ensure reload does not cause unnecessary warm up again. + """ + optim_state_dict = self.optim.state_dict() + optim_state_dict["step"] = self.averager.step + return optim_state_dict + + def load_state_dict(self, state_dict): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`, + but also restores model averager's step value to the one + saved in the provided ``state_dict``. + + If there is no ``"step"`` entry in ``state_dict``, + it will raise a warning and initialize the model averager's step to 0. + """ + self.optim.load_state_dict(state_dict) + if "step" in state_dict: + self.averager.step = state_dict["step"] + else: + warnings.warn( + "Loaded state dict does not contain a step counter for an averager. " + "Setting step counter to 0.", + stacklevel=2, + ) + self.averager.step = 0 + + def step(self): # type: ignore[override] + r""" + Performs a single optimization step (parameter update). + """ + self.optim.step() + self.averager.average_parameters(params=self.param_groups) + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + self.optim.zero_grad(set_to_none=set_to_none) + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7075edd2e5210f1dc3d50aaa09688a4a4e1d09c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/utils.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs + +from torch import optim + +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD + + +# dict to map a user passed in optimizer_class to a functional +# optimizer class if we have already defined inside the +# distributed.optim package, this is so that we hide the +# functional optimizer to user and still provide the same API. +functional_optim_map = { + optim.Adagrad: _FunctionalAdagrad, + optim.Adam: _FunctionalAdam, + optim.AdamW: _FunctionalAdamW, + optim.SGD: _FunctionalSGD, + optim.Adadelta: _FunctionalAdadelta, + optim.RMSprop: _FunctionalRMSprop, + optim.Rprop: _FunctionalRprop, + optim.Adamax: _FunctionalAdamax, +} + + +def register_functional_optim(key, optim): + """ + Interface to insert a new functional optimizer to functional_optim_map + ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key + need not be of :class:`torch.optim.Optimizer` (e.g. for custom optimizers) + Example:: + >>> # import the new functional optimizer + >>> # xdoctest: +SKIP + >>> from xyz import fn_optimizer + >>> from torch.distributed.optim.utils import register_functional_optim + >>> fn_optim_key = "XYZ_optim" + >>> register_functional_optim(fn_optim_key, fn_optimizer) + """ + if key not in functional_optim_map: + functional_optim_map[key] = optim + + +def as_functional_optim(optim_cls: type, *args, **kwargs): + try: + functional_cls = functional_optim_map[optim_cls] + except KeyError as e: + raise ValueError( + f"Optimizer {optim_cls} does not have a functional counterpart!" + ) from e + + return _create_functional_optim(functional_cls, *args, **kwargs) + + +def _create_functional_optim(functional_optim_cls: type, *args, **kwargs): + return functional_optim_cls( + [], + *args, + **kwargs, + _allow_empty_param_list=True, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3183299a48347b4444cfe7b5105c1a1aadc8b4fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py @@ -0,0 +1,1679 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +r"""Zero Redundancy Optimizer.""" + +import collections +import copy +import enum +import inspect +import io +import logging +from collections.abc import Callable +from itertools import chain +from typing import Any + +import torch +import torch.distributed as dist +from torch.distributed.algorithms.join import Join, Joinable, JoinHook +from torch.distributed.optim.utils import functional_optim_map +from torch.optim import Optimizer + + +__all__ = ["ZeroRedundancyOptimizer"] + + +logger = logging.getLogger(__name__) + + +# Credits: classy_vision/generic/distributed_util.py +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note:: + These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [ + _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) + for val in value + ] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device( + val, non_blocking=non_blocking, device=device + ) + for key, val in value.items() + } + + return value + + +def _is_trainable(param: torch.Tensor) -> bool: + r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient.""" + return param.requires_grad + + +def _broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. + """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty( + [int(length_tensor.item())], dtype=torch.uint8, device=device + ) + # pyrefly: ignore [bad-argument-type] + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +class _ZeROJoinHook(JoinHook): + def __init__(self, zero): + assert isinstance(zero, ZeroRedundancyOptimizer), ( + "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " + "instance as the state" + ) + self.zero = zero + super().__init__() + + def main_hook(self): + """ + Perform an optimizer step. + + This step updates the joined process's shard of + the parameters and broadcasts those parameters. + """ + self.zero.step() + + +class _DDPBucketAssignment: + r""" + Represent a :class:`DistributedDataParallel` bucket assignment. + + This means that a (possibly non-strict) subset of the parameters corresponding to + a DDP bucket assigned to a rank to update. + + Attributes: + bucket_index (int): index of the bucket determined by the DDP gradient + bucket all-reduce order. + parameters (List[torch.Tensor]): model parameters in the bucket + assigned to this rank. + offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` + giving the index of the first element in the passed-in + ``parameters``; this equivalently indexes into the + :class:`GradBucket` 's :meth:`gradients`. + device (torch.device): device on which the parameters are stored. + tensor (torch.Tensor): flattened tensor giving the data of the + parameter subset assigned to the rank. + """ + + def __init__( + self, + bucket_index: int, + parameters: list[torch.Tensor], + offset: int, + ): + self.bucket_index = bucket_index + self.parameters = parameters + self.offset = offset + if len(self.parameters) == 0: + raise ValueError("Empty bucket assignment") + # DDP guarantees all parameters in the bucket have the same device + # pyrefly: ignore [read-only] + self.device: torch.device = self.parameters[0].device + self.tensor: torch.Tensor | None = None + + +class _OverlapStatus(enum.IntEnum): + r""" + Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`. + + Attributes: + ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and + is waiting for DDP to finalize its bucketing. + ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that + its bucketing is finalized. The ZeRO instance can now collect the + necessary information about the DDP bucketing. + ``INITIALIZED``: The ZeRO instance is fully initialized and can now + optimize parameters. + """ + + UNINITIALIZED = 0 + DDP_HAS_REBUILT_BUCKETS = 1 + INITIALIZED = 2 + + +class _OverlapInfo: + r""" + Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. + + Arguments: + world_size (int): world size of the process group being used. + + Attributes: + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity following + a threshold given by the total parameter size divided by the world + size; if ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); + this should be set to the value passed into the hook constructor. + status (_OverlapStatus): current status; see :class:`_OverlapStatus` + for more information. + params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]`` + gives the model parameters in the ``i``th bucket. + params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]`` + gives the model parameters assigned to the ``i``th rank, where the + parameters are grouped by increasing bucket indices. + offsets (Dict[int, int]): maps from bucket index to the offset in + ``self.params_per_rank[rank]`` giving the index of the first + parameter in that bucket, where ``rank`` is this process's own + rank; the keys of this :class:`dict` are the bucket indices + assigned to this rank. + num_bucket_assignments (int): total number of bucket assignments across + all ranks; this is equal to the number of + :class:`DistributedDataParallel` gradient buckets if + ``shard_buckets=False`` and possibly greater otherwise. + total_size (int, optional): total size of all buckets (i.e. sum of + ``param.numel()`` for all ``param`` across all buckets) if + ``shard_buckets=True``; otherwise, ``None``. + broadcast_handles (List[Work]): :class:`list` of async work handles for + the parameter broadcasts. + bucket_index_to_future (Dict[int, torch.futures.Future]): + :class:`dict` mapping bucket index to the corresponding all-reduce + future. + bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict` + mapping bucket index to the corresponding bucket. + bucket_indices_seen (List[int]): :class:`list` of the bucket indices + seen on this iteration. + """ + + def __init__(self, world_size) -> None: + self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED + self.shard_buckets: bool = False + + # Modified per bucket reconstruction + self.params_per_bucket: list[list[torch.Tensor]] = [] + self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)] + self.offsets: dict[int, int] = {} + # Group Ranks + self.assigned_ranks_per_bucket: list[set[int]] = [] + self.num_bucket_assignments: int = 0 + self.total_size: int | None = None + + # Modified per iteration + self.broadcast_handles: list[Any] = [] + self.bucket_indices_seen: list[int] = [] + # Used by `hook_with_zero_step()` + self.bucket_index_to_future: dict[int, torch.futures.Future] = {} + self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {} + + def wait_for_broadcasts(self) -> None: + r""" + Wait for all parameter broadcasts. + + This function should be called once all broadcasts have been scheduled, + meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` + in preparation for the next iteration. + """ + assert len(self.broadcast_handles) == self.num_bucket_assignments, ( + f"Missing at least one broadcast handle on rank {dist.get_rank()}" + ) + _ = [x.wait() for x in self.broadcast_handles] + self.broadcast_handles.clear() + + def clear_per_iter_info(self) -> None: + r""" + Clear the data structures that are modified per-iteration. + + This function should be called at the end of an iteration. + """ + self.bucket_indices_seen.clear() + self.bucket_index_to_future.clear() + self.bucket_index_to_bucket.clear() + + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + r""" + Wrap an arbitrary :class:`optim.Optimizer ` and shards its states across ranks in the group. + + The sharing is done as described by `ZeRO `_. + + The local optimizer instance in each rank is only + responsible for updating approximately ``1 / world_size`` parameters and + hence only needs to keep ``1 / world_size`` optimizer states. After + parameters are updated locally, each rank will broadcast its parameters to + all other peers to keep all model replicas in the same state. + ``ZeroRedundancyOptimizer`` can be used in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak + memory consumption. + + ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number + of parameters at each rank. Each parameter belongs to a single rank and is + not divided among ranks. The partition is arbitrary and might not match the + the parameter registration or usage order. + + Arguments: + params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s + or :class:`dict` s giving all parameters, which will be sharded + across ranks. + + Keyword Args: + optimizer_class (:class:`torch.nn.Optimizer`): the class of the local + optimizer. + process_group (``ProcessGroup``, optional): ``torch.distributed`` + ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by + :meth:`torch.distributed.init_process_group`). + parameters_as_bucket_view (bool, optional): if ``True``, parameters are + packed into buckets to speed up communication, and ``param.data`` + fields point to bucket views at different offsets; if ``False``, + each individual parameter is communicated separately, and each + ``params.data`` stays intact (default: ``False``). + overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is + overlapped with :class:`DistributedDataParallel` 's gradient + synchronization; this requires (1) either a functional optimizer + for the ``optimizer_class`` argument or one with a functional + equivalent and (2) registering a DDP communication hook + constructed from one of the functions in ``ddp_zero_hook.py``; + parameters are packed into buckets matching those in + :class:`DistributedDataParallel`, meaning that the + ``parameters_as_bucket_view`` argument is ignored. + If ``False``, :meth:`step` runs disjointly after the backward pass + (per normal). + (default: ``False``) + **defaults: any trailing arguments, which are forwarded to the local + optimizer. + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> from torch.distributed.optim import ZeroRedundancyOptimizer + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) + >>> ddp = DDP(model, device_ids=[rank]) + >>> opt = ZeroRedundancyOptimizer( + >>> ddp.parameters(), + >>> optimizer_class=torch.optim.Adam, + >>> lr=0.01 + >>> ) + >>> ddp(inputs).sum().backward() + >>> opt.step() + + .. warning:: + Currently, ``ZeroRedundancyOptimizer`` requires that all of the + passed-in parameters are the same dense type. + + .. warning:: + If you pass ``overlap_with_ddp=True``, be wary of the following: Given + the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. To adjust for this, one option + is to prepend dummy inputs. + + .. warning:: ZeroRedundancyOptimizer is experimental and subject to change. + """ + + def __init__( + self, + params, + optimizer_class: type[Optimizer], + process_group: Any | None = None, + parameters_as_bucket_view: bool = False, + overlap_with_ddp: bool = False, + **defaults: Any, + ): + r"""Init.""" + # Perform type and assumption checks on the input parameters + params = self._verify_and_init_params(params) + self._verify_same_dense_param_type() + + # NOTE: The parent constructor uses `add_param_group()` which is + # partially overloaded in ZeroRedundancyOptimizer, so we use the + # `initialized` flag to dissociate the behaviour of `add_param_group()` + # between the parent and child. + self.initialized = False + + Optimizer.__init__(self, params, defaults) + Joinable.__init__(self) + # Now, all parameters are held in both `self._all_params` and + # `self.param_groups` + + # Internal data structures (`_cache` indicates lazily evaluated) + self._param_to_rank_cache: dict[torch.Tensor, int] = {} + self._param_to_index_cache: dict[torch.Tensor, int] = {} + self._partition_parameters_cache: list[list[dict]] = [] + self._index_to_param_cache: list[torch.Tensor] = [] + self._device_to_params_per_rank_cache: dict[ + torch.device, list[list[torch.Tensor]] + ] = {} + self._bucket_assignments_per_rank_cache: list[ + dict[int, _DDPBucketAssignment] + ] = [] + self._is_trainable_mask = self._get_is_trainable_mask() + + # Default device for collective communication and buckets + self._default_device = self._all_params[0].device + + self.process_group = ( + process_group if process_group is not None else dist.group.WORLD + ) + self.world_size: int = dist.get_world_size(self.process_group) + self.rank: int = dist.get_rank(self.process_group) + self.global_rank: int = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + self.rank, + ) + + self._overlap_with_ddp: bool = overlap_with_ddp + self._optim_defaults = defaults + self._optim_constructor = self._get_optimizer_constructor(optimizer_class) + + # If `overlap_with_ddp=True`, local optimizer initialization is delayed + # to run time after the necessary information has been collected + if not overlap_with_ddp: + self._init_local_optimizer() + else: + self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size) + if parameters_as_bucket_view: + logger.warning( + "`parameters_as_bucket_view=True` will be ignored since " + "`overlap_with_ddp=True`; instead, a different bucketing " + "strategy will be used" + ) + + # `self._buckets` is used if `parameters_as_bucket_view=True`, in + # which case parameter data is flattened into contiguous bucket tensors + self.parameters_as_bucket_view = parameters_as_bucket_view + self._buckets: list[list[torch.Tensor]] = [] + self._build_param_buckets() + + # Optional consolidated optimizer state, only populated if this rank + # is the target in `consolidate_state_dict()` + self._all_state_dicts: list[dict[str, Any]] = [] + + self.initialized = True + + def _clear_cache(self) -> None: + r"""Clear the cached data structures giving partition information.""" + self._partition_parameters_cache.clear() + self._param_to_rank_cache.clear() + self._index_to_param_cache.clear() + self._param_to_index_cache.clear() + self._device_to_params_per_rank_cache.clear() + self._bucket_assignments_per_rank_cache.clear() + + def add_param_group(self, param_group: dict[str, Any]) -> None: + r""" + Add a parameter group to the :class:`Optimizer` 's ``param_groups``. + + This can be useful when fine tuning a pre-trained network, as frozen + layers can be made trainable and added to the :class:`Optimizer` as + training progresses. + + Arguments: + param_group (dict): specifies the parameters to be optimized and + group-specific optimization options. + + .. warning:: This method handles updating the shards on all partitions + but needs to be called on all ranks. Calling this on a subset of + the ranks will cause the training to hang because communication + primitives are called depending on the managed parameters and + expect all the ranks to participate on the same set of parameters. + """ + if self.initialized and self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only " + "supports a single parameter group" + ) + + super().add_param_group(param_group) + # NOTE: The rest of the method assumes that the call to the parent's + # `add_param_group()` appends the new parameter group and preserves + # the previous parameter-group ordering + + if self.initialized: + # Force a re-partitioning of the parameters + self._clear_cache() + param_groups = self._partition_parameters()[self.rank] + # NOTE: All parameters in the old parameter groups should be + # assigned to the same ranks so that the local optimizers do not + # need to be reinitialized + + # Add the parameters assigned to this rank from the new parameter + # group to the local optimizer, if any + if len(param_groups) == len(self.optim.param_groups) + 1: + self.optim.add_param_group(param_groups[-1]) + + # Update the bucketing strategy accordingly + if self.parameters_as_bucket_view: + self._build_param_buckets() + + def consolidate_state_dict(self, to: int = 0) -> None: + r""" + Consolidate a list of ``state_dict`` s (one per rank) on the target rank. + + Arguments: + to (int): the rank that receives the optimizer states (default: 0). + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + + .. warning:: This needs to be called on all ranks. + """ + self._check_overlap_initialized() + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Pull the sharded state from all ranks and store them in rank order + empty_messenger = torch.tensor( + [0], dtype=torch.uint8, device=self._default_device + ) + + # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`) + # due to compatibility issues with NCCL backend; a possible follow-up + # is to move all sharded state management to RPC RRef + self._all_state_dicts = [] + for rank in range(self.world_size): + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + if self.rank == to: + # Consolidate all local `state_dict`s on this rank, storing on + # CPU to save GPU memory + if rank == self.rank: + # Directly append own optimizer state + self._all_state_dicts.append( + _recursive_copy_to_device( + self.optim.state_dict(), + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + # Receive the optimizer state from the source rank + local_state_dict = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + self._all_state_dicts.append( + _recursive_copy_to_device( + local_state_dict, + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + if rank == self.rank: + # Send the optimizer state to the target rank + _ = _broadcast_object( + self.optim.state_dict(), + src_rank=self.global_rank, + group=self.process_group, + device=self._default_device, + ) + elif rank != to: + # Discard the received object; `broadcast()` is used for + # compatibility reasons + _ = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + + def _verify_params_per_rank( + self, + params_per_rank: list[list[torch.Tensor]], + ) -> None: + r""" + Verify ``params_per_rank`` for :meth:`_partition_parameters`. + + The verification is done by checking that ``params_per_rank`` has length equal + to the world size and that it does not contain any parameters not passed into the + :class:`ZeroRedundancyOptimizer` constructor. + + The parameters in ``params_per_rank`` being a strict subset of those + passed into the constructor is valid since some parameters may be + frozen. + + Raises: + ValueError: if ``params_per_rank`` does not have length equal to + the world size or if it contains a parameter that was not + passed into the :class:`ZeroRedundancyOptimizer` constructor. + """ + if len(params_per_rank) != self.world_size: + raise ValueError( + "`params_per_rank` must have length equal to the world size" + ) + all_params_set = set(self._all_params) + for params in params_per_rank: + for param in params: + if param not in all_params_set: + raise ValueError( + "Passing a new parameter in `params_per_rank` that " + "was not passed into the ZeroRedundancyOptimizer " + "constructor" + ) + + def _partition_param_group( + self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]] + ) -> None: + r""" + Partition the parameter group ``param_group`` according to ``params_per_rank``. + + The partition will modify the ``self._partition_parameters_cache``. This method should + only be used as a subroutine for :meth:`_partition_parameters`. + + Arguments: + param_group (dict[str, Any]): a parameter group as normally defined + in an optimizer state. + params_per_rank (list[list[torch.Tensor]]): a :class:`list` of + length world size containing :class:`list` s of parameters to + assign to each rank. + """ + for rank, params in enumerate(params_per_rank): + rank_param_group = copy.copy(param_group) + rank_param_group["params"] = params + self._partition_parameters_cache[rank].append(rank_param_group) + + def _partition_parameters( + self, + params_per_rank: list[list[torch.Tensor]] | None = None, + ) -> list[list[dict]]: + r""" + Partitions parameters across distributed data parallel ranks. + + Arguments: + params_per_rank (list[list[torch.Tensor]], optional): a + :class:`list` of length world size containing :class:`list` s + of parameters to assign to each rank; this provides a way to + specify a partition manually. + If ``None``, the parameters are partitioned according to an + internal algorithm. + (default: ``None``) + + Returns: + A :class:`list` where each element of the list contains the + ``param_groups`` for a rank (which itself is a :class:`list` of + :class:`dict`); element 0 corresponds to rank 0, etc.; each rank + stores the ``param_groups`` for all ranks for the collective + communication in :meth:`step`. + + Raises: + ValueError: see :meth:`_validate_params_per_rank`. + RuntimeError: if ``params_per_rank`` is not ``None`` and this + :class:`ZeroRedundancyOptimizer` instance is using more than + one parameter group. + """ + if params_per_rank is None: + # Partition the parameters optimizing for uniformity + if len(self._partition_parameters_cache) == 0: + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + sizes = [0] * self.world_size + for param_group in self.param_groups: + param_group_params_per_rank: list[list] = [ + [] for _ in range(self.world_size) + ] + # Sort the parameters by size (largest first) + params_sorted = sorted( + param_group["params"], key=lambda t: t.numel(), reverse=True + ) + for param in params_sorted: + # Greedily add the parameter to rank with smallest size so far + rank = self._get_min_index(sizes) + param_group_params_per_rank[rank].append(param) + sizes[rank] += param.numel() + # Apply the constructed partition of the parameter group + self._partition_param_group( + param_group, param_group_params_per_rank + ) + + return self._partition_parameters_cache + + # Partition the parameters according to `params_per_rank` + assert len(self._partition_parameters_cache) == 0, ( + "Specifying `params_per_rank` should only be done when the " + "parameters have not been partitioned yet" + ) + if len(self.param_groups) != 1: + raise RuntimeError( + "Specifying `params_per_rank` only supports a single parameter group" + ) + self._verify_params_per_rank(params_per_rank) + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + + # Apply the passed-in partition of the parameter group + param_group = self.param_groups[0] + self._partition_param_group(param_group, params_per_rank) + + return self._partition_parameters_cache + + @property + def _param_to_rank(self) -> dict[torch.Tensor, int]: + r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" + if len(self._param_to_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + self._param_to_rank_cache[param] = rank + return self._param_to_rank_cache + + @property + def _param_to_index(self) -> dict[torch.Tensor, int]: + r""" + :class:`dict` mapping parameters to their indices in the global optimizer state. + + NOTE: This assumes that the global optimizer state's indexing (in + ``state_dict``) follows a linear ordering over the parameter groups. + """ + if len(self._param_to_index_cache) == 0: + self._param_to_index_cache = { + p: i + for i, p in enumerate( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + } + return self._param_to_index_cache + + @property + def _index_to_param(self) -> list[torch.Tensor]: + r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" + if len(self._index_to_param_cache) == 0: + self._index_to_param_cache = list( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + return self._index_to_param_cache + + def _broadcast_params_from_rank(self, rank: int): + r""" + Broadcast the shard of parameters from a given rank to all other ranks asynchronously. + + Arguments: + rank (int): the source rank. + + Returns: + A :class:`list` of async work handles for the ``broadcast()`` s + performed to synchronize the parameters. + """ + assert not self._overlap_with_ddp, ( + "`_broadcast_params_from_rank()` should not be used if " + "`overlap_with_ddp=True`; instead, the broadcasting should " + "happen in the DDP communication hook" + ) + handles = [] + if self.parameters_as_bucket_view: + for dev_i_buckets in self._buckets: + bucket = dev_i_buckets[rank] + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + handles.append( + dist.broadcast( + tensor=bucket, + src=global_rank, + group=self.process_group, + async_op=True, + ) + ) + else: + param_groups = self._partition_parameters()[rank] + global_rank = dist.distributed_c10d.get_global_rank( + # pyrefly: ignore [bad-argument-type] + self.process_group, + rank, + ) + for param_group in param_groups: + handles.extend( + dist.broadcast( + tensor=param.data, + src=global_rank, + group=self.process_group, + async_op=True, + ) + for param in param_group["params"] + ) + return handles + + def _sync_params(self): + r""" + Sync all parameter shards across the ranks. + + This rank sends its shard of the parameters to all other ranks and + receives a shard from each other rank. This is done using + ``broadcast()``. Parameters are sent bucket-by-bucket if + ``parameters_as_bucket_view=True``and sent parameter-by-parameter + otherwise. + """ + handles = [] + for rank in range(self.world_size): + handles.extend(self._broadcast_params_from_rank(rank)) + _ = [x.wait() for x in handles] + + @property + def _device_to_params_per_rank( + self, + ) -> dict[torch.device, list[list[torch.Tensor]]]: + r""" + Return device parameters assigned per rank. + + :class:`dict` mapping each device to a :class:`list` of the per-rank parameter + lists filtered to only include the parameters stored on that device. + Each per-rank parameter list gives the parameters assigned to that rank + to update. + + This is used for constructing the parameter buckets if + ``parameters_as_bucket_view=True``. + + Let ``dev_i`` denote the ``i``th device for this rank. Then: + ``dev_0`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_0``, + rank 1's assigned parameters stored on ``dev_0``, + ... + ``dev_1`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_1``, + rank 1's assigned parameters stored on ``dev_1``, + ... + ... + """ + assert self.parameters_as_bucket_view, ( + "`_device_to_params_per_rank` should only be used if " + "`parameters_as_bucket_view=True`" + ) + if len(self._device_to_params_per_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + device = param.device + if device not in self._device_to_params_per_rank_cache: + self._device_to_params_per_rank_cache[device] = [ + [] for _ in range(self.world_size) + ] + self._device_to_params_per_rank_cache[device][rank].append( + param + ) + return self._device_to_params_per_rank_cache + + def _get_min_index( + self, + values: list[int], + disallowed_indices: set[int] | None = None, + ) -> int: + r""" + Return ``values.index(min(values))``, except only uses one pass. + + It also excludes any indices in ``disallowed_indices`` if provided. + + Arguments: + values: (List[int]): :class:`list` of values. + disallowed_indices (Optional[set[int]]): indices that are + disallowed from being the returned min index. + """ + min_index = -1 + min_value = float("inf") + for i, value in enumerate(values): + if disallowed_indices and i in disallowed_indices: + continue + if value < min_value: + min_value = value + min_index = i + assert min_index >= 0, "All indices are disallowed" + return min_index + + def _assign_bucket_subset_to_rank( + self, + bucket_index: int, + bucket_params: list[torch.Tensor], + bucket_offset: int, + assigned_rank: int, + assigned_ranks_per_bucket: list[set[int]], + ) -> None: + r""" + Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. + + The model parameters given by ``bucket_params`` represents a (possibly non-strict) + subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + gradient bucket. + bucket_params (List[torch.Tensor]): subset of the parameters + corresponding to the bucket to assign. + bucket_offset (int): offset giving the index of the first element + in ``bucket_params`` in the bucket's full parameter list. + assigned_rank (int): group rank to assign to. + assigned_ranks_per_bucket (list[set[int]]): :class:`set` of group ranks + assigned to each bucket. + """ + overlap_info = self._overlap_info + if len(bucket_params) == 0: + raise ValueError("Empty bucket assignment") + params_per_rank = overlap_info.params_per_rank + offsets = overlap_info.offsets + + self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = ( + _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + ) + if self.global_rank == assigned_rank: + offsets[bucket_index] = len(params_per_rank[assigned_rank]) + params_per_rank[assigned_rank].extend(bucket_params) + assigned_ranks_per_bucket[bucket_index].add(assigned_rank) + self._overlap_info.num_bucket_assignments += 1 + + @property + def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]: + r""" + Return DDP bucket parameters assigned per rank. + + :class:`list` of length world size consisting of :class:`dict` s + mapping bucket indices to :class:`_DDPBucketAssignment` s for each + rank. + """ + assert self._overlap_with_ddp, ( + "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + ) + if len(self._bucket_assignments_per_rank_cache) > 0: + return self._bucket_assignments_per_rank_cache + + overlap_info = self._overlap_info + assert overlap_info.status == _OverlapStatus.INITIALIZED + + self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)] + params_per_bucket = overlap_info.params_per_bucket + + if overlap_info.shard_buckets: + # Define the assignment threshold to approximate uniformity + assert overlap_info.total_size is not None, "`total_size` was not computed" + threshold = overlap_info.total_size / self.world_size # type: ignore[operator] + size_per_rank = [0 for _ in range(self.world_size)] + + num_buckets = len(params_per_bucket) + overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)] + assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket + if not overlap_info.shard_buckets: + # Assign each DDP bucket entirely to a single rank + for bucket_index, bucket_params in enumerate(params_per_bucket): + assert len(bucket_params) > 0, "Empty bucket" + assigned_rank = self._get_assigned_rank(bucket_index) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params, + 0, + assigned_rank, + assigned_ranks_per_bucket, + ) + else: + # Assign each DDP bucket to possibly multiple ranks + # Specifically, sort the DDP buckets by increasing size, and for + # each bucket, iteratively assign the maximal unassigned subset + # with size less than `threshold` to the rank with the least total + # size so far -- each such assignment is represented by a + # `_DDPBucketAssignment` instance and only contains parameters from + # a single DDP bucket + params_per_bucket_enum = sorted( + enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1]) + ) + for bucket_index, bucket_params in params_per_bucket_enum: + assert len(bucket_params) > 0, "Empty bucket" + bucket_offset = 0 + assignment_size = 0 + for param_index, param in enumerate(bucket_params): + param_numel = param.numel() + if ( + # pyrefly: ignore [unbound-name] + assignment_size + param_numel >= threshold + and param_index > bucket_offset + ): + assigned_rank = self._get_min_index( + # pyrefly: ignore [unbound-name] + size_per_rank, + assigned_ranks_per_bucket[bucket_index], + ) + # Include up to but not including the parameter that + # exceeded the threshold + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:param_index], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + # pyrefly: ignore [unbound-name] + size_per_rank[assigned_rank] += assignment_size + bucket_offset = param_index + assignment_size = 0 + assignment_size += param_numel + # Assign the remainder of the bucket so that no assignment + # spans across two buckets + assigned_rank = self._get_min_index( + # pyrefly: ignore [unbound-name] + size_per_rank, + assigned_ranks_per_bucket[bucket_index], + ) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + # pyrefly: ignore [unbound-name] + size_per_rank[assigned_rank] += assignment_size + + return self._bucket_assignments_per_rank_cache + + def _local_step( + self, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: + r""" + Perform a single optimizer step without syncing parameters across ranks. + + Arguments: + gradients (list[Optional[torch.Tensor]], optional): a :class:`list` + of length equal to the number of parameters assigned to this + rank containing gradient tensors or ``None`` as its elements; + a ``None`` in the :class:`list` indicates that the + corresponding parameter should not be updated. + If the argument itself is ``None``, then all parameters are + updated, and the gradients are assumed to be already populated. + (default: ``None``) + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers and should be + ``None`` if ``gradients`` is not ``None``; (default: ``None``) + Returns: + Optional loss depending on the underlying local optimizer. + + .. warning:: + The argument ``gradients`` should only be specified (i.e. not + ``None``) if ``overlap_with_ddp=True``, in which case + :class:`ZeroRedundancyOptimizer` wraps a functional optimizer. + """ + Join.notify_join_context(self) + # Check if the model trainability has changed + is_trainable_mask = self._get_is_trainable_mask() + if is_trainable_mask != self._is_trainable_mask: + if self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` " + "does not support changing parameter trainability at run " + "time" + ) + logger.warning( + "ZeroRedundancyOptimizer detected that the trainable " + "parameters changed; rebuilding the parameter buckets if " + "enabled" + ) + self._build_param_buckets() + self._is_trainable_mask = is_trainable_mask + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Run the optimizer step on this shard only + if gradients is None: + loss = ( + self.optim.step(**kwargs) + if closure is None + else self.optim.step(closure=closure, **kwargs) + ) + else: + assert self._overlap_with_ddp, ( + "Specifying `gradients` should not " + "be used when `overlap_with_ddp=False`" + ) + assert closure is None, ( + "`closure` is not supported when using a local functional optimizer" + ) + loss = self.optim.step(gradients=gradients) + + # Sync any updated attributes in the local optimizer to the exposed + # `param_groups` + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + return loss + + # pyrefly: ignore [bad-override] + def step( + self, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: + r""" + Perform a single optimizer step and syncs parameters across all ranks. + + Arguments: + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers. + Returns: + Optional loss depending on the underlying local optimizer. + + .. note:: Any extra parameters are passed to the base optimizer as-is. + """ + if self._overlap_with_ddp: + logger.warning( + "`step()` should not be included in the training loop when " + "`overlap_with_ddp=True`" + ) + return None + + # Perform the local optimizer step + loss = self._local_step(closure=closure, **kwargs) + + # Sync all of the updated parameter shards across the ranks + self._sync_params() + + return loss + + def join_hook(self, **kwargs): + r""" + Return the ZeRO join hook. + + It enables training on uneven inputs by + shadowing the collective communications in the optimizer step. + + Gradients must be properly set before this hook is called. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + This hook does not support any keyword arguments; i.e. ``kwargs`` is + unused. + """ + return _ZeROJoinHook(self) + + @property + def join_device(self) -> torch.device: + r"""Return default device.""" + return self._default_device + + @property + def join_process_group(self) -> Any: + r"""Return process group.""" + return self.process_group + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r""" + Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed. + + Arguments: + state_dict (dict): optimizer state; should be an object returned + from a call to :meth:`state_dict`. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + """ + self._check_overlap_initialized() + + for index, value in state_dict["state"].items(): + param = self._index_to_param[index] + if self._param_to_rank[param] != self.rank: + # Clear any state irrelevant to this rank + state_dict["state"][index] = None + else: + # Load the parameter state to the local optimizer + self.optim.state[param] = _recursive_copy_to_device( + value, non_blocking=True, device=param.device + ) + # Force zero-dimensional tensors (like Adam "step") on CPU + for state_name, state_value in self.optim.state[param].items(): + if torch.is_tensor(state_value) and state_value.dim() == 0: + self.optim.state[param][state_name] = state_value.cpu() + + super().load_state_dict(state_dict) + + # Sync the input state with the exposed and local optimizer states + self._sync_param_groups(state_dict["param_groups"], self.param_groups) + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + def state_dict(self) -> dict[str, Any]: + r""" + Return the last global optimizer state known to this rank. + + .. warning: + If the state has not been consolidated to this rank, this raises a + runtime error, and even if it has, the state may not be up-to-date, + depending on when :meth:`consolidate_state_dict` was last called. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt; or if this method is called without a preceding call + to :meth:`consolidate_state_dict`. + """ + self._check_overlap_initialized() + + if len(self._all_state_dicts) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. " + f"Please call `consolidate_state_dict(to={self.rank})` on " + "all ranks beforehand if you meant to save the global state." + ) + + # Get the possibly-stale global optimizer state that uses global + # parameter indexing + state_dict = super().state_dict() + + # Update the global optimizer state with local state information, + # factoring in the translation from local to global indexing + for rank, local_state_dict in enumerate(self._all_state_dicts): + local_param_groups = local_state_dict["param_groups"] + global_param_groups = self._partition_parameters()[rank] + assert len(local_param_groups) == len(global_param_groups), ( + "Mismatch between number of local and global parameter groups" + ) + + for local_param_group, global_param_group in zip( + local_param_groups, global_param_groups + ): + # `local_param_group` stores local indices, while + # `global_param_group` stores the tensors directly + local_param_indices = local_param_group["params"] + global_params = global_param_group["params"] + + assert len(local_param_indices) == len(global_params), ( + "Mismatch between number of local and global parameters in parameter group" + ) + for local_param_index, global_param in zip( + local_param_indices, global_params + ): + # Update the global parameter state, if any + if local_param_index in local_state_dict["state"]: + global_param_index = self._param_to_index[global_param] + state_dict["state"][global_param_index] = local_state_dict[ + "state" + ][local_param_index] + + # Sort the parameters in the state + state_dict["state"] = dict(sorted(state_dict["state"].items())) + return state_dict + + @staticmethod + def _sync_param_groups( + src_param_groups: list[dict[Any, Any]], + dst_param_groups: list[dict[Any, Any]], + ) -> None: + r""" + Sync the attributes from the source parameter groups to the destination parameter groups. + + Example attributes include learning rate or scheduler attributes. The + two parameter groups should have the same length (i.e. same number of + parameter groups). + + Arguments: + src_param_groups (list[dict]): parameter groups giving the + attribute settings to copy. + dst_param_groups (list[dict]): parameter groups giving the + attribute settings to set. + """ + assert len(src_param_groups) == len(dst_param_groups), ( + "Mismatch between number of source and destination parameter groups" + ) + for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): + # Sync all attributes except the parameters + for attr in filter(lambda x: x != "params", src_param_group.keys()): + dst_param_group[attr] = src_param_group[attr] + + def _build_param_buckets(self) -> None: + r""" + Build parameter buckets if ``parameters_as_bucket_view=True``. + + For each device that stores this rank's parameters, there is a + bucket (represented as a tensor) containing all of the parameters on + that device that are assigned to a given rank in the parameter update + partition. + + This method is called in the constructor and any time parameter + trainability is changed. + + .. warning:: + The current implementation assumes that all of the parameters in a + bucket are of the same dense type when allocating the bucket's + tensor. + + .. warning:: + If the model parameters are stored across more than one device, + then the storage partitioning must be the same across all + processes in order for parameter synchronization to work. + """ + if not self.parameters_as_bucket_view or self._overlap_with_ddp: + return + + # `self._buckets[i][j]` are the parameters stored on device i and + # assigned to rank j + num_devices = len(self._device_to_params_per_rank) + self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment] + + for dev_i, (device, params_per_rank) in enumerate( + self._device_to_params_per_rank.items() + ): + for params in params_per_rank: + bucket_size = 0 + dtype = None + trainable_params = [] + for param in params: + if not _is_trainable(param): + # Clone in case the parameter was previously part of + # a bucket to avoid the data from being destroyed + param.data = param.data.detach().clone() + else: + bucket_size += param.numel() + trainable_params.append(param) + dtype = param.dtype # assumes all same dtype + + if bucket_size == 0: + # Create a dummy bucket if there are no parameters + bucket = torch.zeros(1, device=device) + else: + # Construct the bucket (assuming all dense and same dtype) + bucket = torch.empty(bucket_size, dtype=dtype, device=device) + offset = 0 + for param in trainable_params: + offset_next = offset + param.numel() + bucket[offset:offset_next].copy_(param.data.flatten()) + param.data = bucket[offset:offset_next].view_as(param.data) + offset = offset_next + self._buckets[dev_i].append(bucket) # type: ignore[arg-type] + + def _build_ddp_param_buckets(self) -> None: + r""" + Build the DDP bucket with parameters assigned to this rank. + + For each DDP bucket with parameters assigned to this rank, flattens the + data of those parameters into a single tensor and saves the tensor to + the ``tensor`` attribute in the corresponding + :class:`_DDPBucketAssignment` instance stored in + ``self._bucket_assignments_per_rank``. + + :class:`DistributedDataParallel` guarantees that the parameters + corresponding to a gradient bucket have the same device and the same + dtype. + """ + for bucket_assignments in self._bucket_assignments_per_rank: + for bucket_assignment in bucket_assignments.values(): + params = bucket_assignment.parameters + bucket_size = 0 + dtype = None + for param in params: + assert _is_trainable(param), ( + "Model parameter " + "corresponding to a gradient in a DDP bucket should " + "require a gradient" + ) + bucket_size += param.numel() + dtype = param.dtype # assumes all same dtype + assert bucket_size > 0, "Empty bucket" + + # Construct the bucket tensor (assuming all dense and same dtype) + tensor = torch.empty( + bucket_size, dtype=dtype, device=bucket_assignment.device + ) + offset = 0 + for param in params: + offset_next = offset + param.numel() + tensor[offset:offset_next].copy_(param.data.flatten()) + param.data = tensor[offset:offset_next].view_as(param.data) + offset = offset_next + bucket_assignment.tensor = tensor + + def _verify_and_init_params( + self, + params: Any, + ) -> list[torch.Tensor] | list[dict]: + r""" + Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. + + The initializagtion will first make sure that provided ``params`` is valid. + + Arguments: + params (Any): Candidate parameter list or parameter groups to verify. + + Raises: + TypeError: ``params`` has an invalid type. + ValueError: ``params`` is empty. + + Returns: + The persistent form of ``params`` to be passed into the parent + :class:`Optimizer` constructor -- i.e. returns ``params`` as a + :class:`list` to ensure that it can be iterated over again. + """ + if isinstance(params, torch.Tensor): + raise TypeError( + "`params` argument should be an iterable of " + f"Tensors, but got {torch.typename(params)}" + ) + try: + all_params = list(params) + except TypeError as e: + raise TypeError( + "`params` argument should be an iterable of Tensors" + f" or dicts, but got {torch.typename(params)}" + ) from e + if len(all_params) == 0: + raise ValueError("ZeroRedundancyOptimizer got an empty parameter list") + all_tensors = True + all_dicts = True + for param in all_params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError( + "`params` argument should be an iterable of Tensors or dicts" + ) + # Ensure that `self._all_params` contains a list of all parameters + if all_tensors: + self._all_params = all_params + elif all_dicts: + self._all_params = [] + # `all_params` contains parameter groups (not parameters) + for param_group in all_params: + if "params" not in param_group: + raise ValueError( + "Each parameter group passed-in via `params` must " + "have a 'params' key mapping to the parameters in " + "the group" + ) + self._all_params.extend(param_group["params"]) + return all_params + + def _verify_same_dense_param_type(self) -> None: + r""" + Verify that all parameters are of the same dense type. + + The method assumes that ``self._all_params`` has been initialized + and is non-empty. + + Raises: + ValueError: ``params`` contains sparse parameters or parameters + of varying dense types. + + NOTE: This method can be removed once support for sparse parameters + and varying parameter types is added. + """ + typename = torch.typename(self._all_params[0]) + if self._all_params[0].is_sparse: + raise ValueError( + "ZeroRedundancyOptimizer only supports using " + "the same dense type for all parameters but got " + f"{typename}" + ) + for param in self._all_params[1:]: + other_typename = torch.typename(param) + if other_typename != typename: + raise ValueError( + "ZeroRedundancyOptimizer only supports " + "using the same dense type for all " + f"parameters but got both {typename} and " + f"{other_typename}" + ) + + def _get_is_trainable_mask(self) -> list[bool]: + r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" + return list(map(_is_trainable, self._all_params)) + + def _init_local_optimizer(self) -> None: + r""" + Initialize this rank's local optimizer, responsible for its subset of the parameters. + + The local optimizer is saved in ``self.optim``. + """ + assert self._optim_constructor is not None, ( + "The local optimizer class has not been set" + ) + + param_groups = self._partition_parameters()[self.rank] + # `overlap_with_ddp=True` requires a local functional optimizer + if self._overlap_with_ddp: + # Functional optimizers only support a single parameter group and + # require passing in the parameters as a list + assert len(param_groups) == 1, ( + "Initializing the local " + "functional optimizer with more than one parameter group" + ) + params = param_groups[0]["params"] + # Try to pass `_allow_empty_param_list=True` to avoid erroring + if ( + "_allow_empty_param_list" + in inspect.signature(self._optim_constructor).parameters + ): + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults, _allow_empty_param_list=True + ) + else: + logger.warning( + "%s does not support the argument " + "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " + "error due to an empty parameter list", + self._optim_constructor, + ) + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults + ) # type: ignore[no-redef] + + # Log information about the DDP and ZeRO bucketing + if dist.get_debug_level() != dist.DebugLevel.OFF: + local_numel = sum(p.numel() for p in params) + num_assigned_buckets = len( + self._bucket_assignments_per_rank[self.global_rank] + ) + logger.info( + "rank %s with %s parameters across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, + ) + if self.global_rank == 0: + logger.info( + "%s DDP buckets and %s bucket assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, + ) + else: + # NOTE: Passing `param_groups` into the local optimizer constructor + # bypasses the empty parameter list check + self.optim: Optimizer = self._optim_constructor( + param_groups, **self._optim_defaults + ) # type: ignore[no-redef] + + # TODO: Manually add `self.param_groups` if using a functional + # optimizer; remove this if/when the functional optimizers support + # multiple parameter groups + if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"): + assert hasattr(self.optim, "param_group"), ( + "The functional optimizer should set at least one of the " + "attributes `param_group` or `param_groups`" + ) + self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined] + + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + def _init_zero_for_overlap(self) -> None: + r"""Perform a delayed initialization of the local optimizer and the supporting data structures.""" + assert self._overlap_with_ddp, ( + "`_init_zero_for_overlap()` should only be called when " + "`overlap_with_ddp=True`" + ) + self._overlap_info.status = _OverlapStatus.INITIALIZED + self._clear_cache() + self._partition_parameters(self._overlap_info.params_per_rank) + self._build_ddp_param_buckets() + self._init_local_optimizer() + + def _get_assigned_rank(self, bucket_index: int) -> int: + r""" + Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + bucket for which to get the assigned rank. + """ + assert not self._overlap_info.shard_buckets, ( + "The bucket assignment requires global bucket information and " + "will be computed later; there should be no need to use this " + "method" + ) + return bucket_index % self.world_size + + def _check_overlap_initialized(self): + r""" + Check the delayed initialization depending on the value of ``overlap_with_ddp``. + + The delayed initialization has occurred (see + :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and + raises a ``RuntimeError`` if not. This should preface methods that + should not be run before that delayed initialization. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and + :meth:`_init_zero_for_overlap` has not been called. + """ + if ( + self._overlap_with_ddp + and self._overlap_info.status != _OverlapStatus.INITIALIZED + ): + raise RuntimeError( + "This method should not be called until this " + "ZeroRedundancyOptimizer instance has been fully " + "initialized" + ) + + def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: + r""" + Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``. + + Returns: + - ``optimizer_class`` if ``overlap_with_ddp=False`` and + ``optimizer_class`` is not a functional optimizer. + - ``optimizer_class`` if ``overlap_with_ddp=True`` and + ``optimizer_class`` is already a functional optimizer. + - The functional equivalent of ``optimizer_class`` if + ``overlap_with_ddp=True`` and ``optimizer_class`` is not + already a functional optimizer (assuming the equivalent + exists). + + Raises: + ValueError: + + - if ``overlap_with_ddp=True`` but ``optimizer_class`` is + neither a functional optimizer nor translatable to a + functional optimizer. + - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a + functional optimizer. + """ + functional_optims = functional_optim_map.values() + if not self._overlap_with_ddp: + if optimizer_class in functional_optims: + # Using a functional optimizer is only supported when + # `overlap_with_ddp=True` + raise ValueError( + f"Passing in a functional optimizer {optimizer_class} " + "when `overlap_with_ddp=False`" + ) + else: + return optimizer_class + else: + if optimizer_class in functional_optims: + # Already a functional optimizer + return optimizer_class + elif optimizer_class in functional_optim_map: + # Translate the passed-in optimizer class to its functional + # equivalent if `overlap_with_ddp=True` + optim_constructor = functional_optim_map[optimizer_class] + logger.info( + "Using the functional optimizer %s " + "instead of %s since " + "`overlap_with_ddp=True`", + optim_constructor, + optimizer_class, + ) + return optim_constructor + else: + raise ValueError( + "Using `ddp_with_overlap=True` requires using a " + "functional optimizer, but there is no supported functional " + f"optimizer equivalent for {optimizer_class}" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8ffbb04f13ffcfdba07589eac0594c80cc28968d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +import enum +from collections.abc import Callable +from typing import Any, overload + +import torch +from torch.distributed.algorithms.join import Joinable, JoinHook +from torch.optim import Optimizer + +class _ZeROJoinHook(JoinHook): + zero: Any = ... + def __init__(self, zero: Any) -> None: ... + def main_hook(self) -> None: ... + +class _DDPBucketAssignment: + bucket_index: int + parameters: list[torch.Tensor] + offset: int + device: torch.device + tensor: torch.Tensor | None + +class _OverlapStatus(enum.IntEnum): + UNINITIALIZED = ... + DDP_HAS_REBUILT_BUCKETS = ... + INITIALIZED = ... + +class _OverlapInfo: + status: Any = ... + params_per_bucket: Any = ... + params_per_rank: Any = ... + offsets: Any = ... + broadcast_handles: Any = ... + bucket_index_to_future: Any = ... + bucket_index_to_bucket: Any = ... + bucket_indices_seen: Any = ... + assigned_ranks_per_bucket: list[set[int]] = ... + total_size: int = ... + shard_buckets: bool = ... + def __init__(self) -> None: ... + def wait_for_broadcasts(self) -> None: ... + def clear_per_iter_info(self) -> None: ... + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + functional_optim_map: Any = ... + initialized: bool = ... + process_group: Any = ... + world_size: int = ... + rank: int = ... + global_rank: int = ... + parameters_as_bucket_view: bool = ... + optim: Any = ... + _device_to_device_index: dict[torch.device, int] = ... + _overlap_with_ddp: bool = ... + _overlap_info: _OverlapInfo = ... + _buckets: list[list[torch.Tensor]] = ... + _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ... + def __init__( + self, + params: Any, + optimizer_class: type[Optimizer], + process_group: Any | None = ..., + parameters_as_bucket_view: bool = ..., + overlap_with_ddp: bool = ..., + **defaults: Any, + ) -> None: ... + def add_param_group(self, param_group: dict[str, Any]) -> None: ... + def consolidate_state_dict(self, to: int = ...) -> None: ... + @overload + def step(self, closure: None = None, **kwargs: Any) -> None: ... + @overload + def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ... + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... + def state_dict(self) -> dict[str, Any]: ... + def _local_step( + self, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: ... + def _get_assigned_rank(self, bucket_index: int) -> int: ... + def _init_zero_for_overlap(self) -> None: ... + def join_hook(self, **kwargs): ... + @property + def join_device(self) -> torch.device: ... + def join_process_group(self) -> Any: ... diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py new file mode 100644 index 0000000000000000000000000000000000000000..eae7def75d5faf1b9427e7a477b866b8229b1651 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py @@ -0,0 +1,1257 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import logging +import operator +from collections import defaultdict +from collections.abc import Callable +from enum import Enum +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Any, Union + +import torch +import torch.fx as fx +from torch.distributed import ProcessGroup +from torch.export import ExportedProgram +from torch.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) +from torch.fx.node import map_aggregate +from torch.fx.passes.split_module import split_module + +from ._backward import _null_coalesce_accumulate, stage_backward +from ._unflatten import _outline_submodules +from ._utils import PipeInfo +from .stage import _PipelineStage + + +logger = logging.getLogger(__name__) + +# TODO: +# 1. investigate gradient sync for shared parameters. how does DDP do it? +# 2. Add parameter movement to split_module + + +PP_SUBMOD_PREFIX = "submod_pp" + + +def get_submod_name(stage_idx: int): + """Returns the name of the submod for a given stage index. + For example, "submod_pp_0", "submod_pp_1", etc. + """ + return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)]) + + +def _find_loss_from_output_and_spec(output_val, spec_val): + if spec_val is False: + return None + if spec_val is True: + if not isinstance(output_val, fx.Node): + raise RuntimeError( + f"Loss spec must specify a dynamic value but got {output_val}" + ) + return output_val + + if isinstance(spec_val, (tuple, list)): + if not isinstance(output_val, (tuple, list)): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if len(output_val) != len(spec_val): + raise RuntimeError( + f"Output value {output_val} must match length of loss specification " + f"{spec_val}" + ) + for out, spec in zip(output_val, spec_val): + loss_val = _find_loss_from_output_and_spec(out, spec) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + if isinstance(spec_val, dict): + if not isinstance(output_val, dict): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if set(output_val.keys()) != set(spec_val.keys()): + raise RuntimeError( + f"Output value {output_val} must match keys of loss specification " + f"{spec_val}" + ) + for k in spec_val: + loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k]) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification") + + +def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec): + output_nodes = [n for n in g.nodes if n.op == "output"] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + output_val = output_node.args[0] + generated_spec: Any = None + + if isinstance(mod, TrivialLossWrapper): + # TrivialLossWrapper is pre-defined by PiPPy. + # It has loss as the only output so we can safely assume the first output arg is the loss. + assert len(output_node.args) == 1 + loss_node = output_val + generated_spec = TrivialLossWrapper.loss_spec + elif output_loss_value_spec is None: + # Use default spec, i.e. search for "loss" in output values + if isinstance(output_val, dict) and "loss" in output_val: + loss_node = output_val["loss"] + generated_spec = {k: k == "loss" for k in output_val} + else: + loss_node = None + generated_spec = None + else: + loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec) + generated_spec = output_loss_value_spec + + return loss_node, output_node, generated_spec + + +def _insert_stage_symbolic_backward( + g: fx.Graph, + loss_node: fx.Node, + output_node: fx.Node, +): + # Collect metadata about tuple output values. TODO: move this to split_module or FX IR + tuples: dict[fx.Node, tuple] = {} + for node in reversed(g.nodes): + if node.op == "call_function": + # In the forward pass, only emit placeholder, module calls, and + # getitem calls. If we have a target other than getitem in this + # (forward-only) code, there is a bug. + assert node.target is operator.getitem, ( + "Found non-getitem call in forward pass. Please report a bug to PiPPy" + ) + assert len(node.args) == 2, ( + "Found malformed getitem call. Please report a bug to PiPPy" + ) + indexed_value, node_idx = tuple(node.args) + + # indexed_value is a collection that we are indexing into. It could + # exist in the tuples map if we've processed another `getitem` + # already. + existing_list_size = ( + len(tuples[indexed_value]) if indexed_value in tuples else -1 + ) + new_list_size = max(node_idx + 1, existing_list_size) + + reconstructed_list = [None for _ in range(new_list_size)] + + # Copy over existing elements if present + if indexed_value in tuples: + for i, val in enumerate(tuples[indexed_value]): + reconstructed_list[i] = val + + # Populate value represented by this node + reconstructed_list[node_idx] = node + + tuples[indexed_value] = tuple(reconstructed_list) + + # Keep track of nodes that dominate the loss node. + # We will only emit backward operations for nodes that can contribute + # to the specified loss value. + live_nodes = {loss_node: None} + val_to_grad: dict[fx.Node, fx.Node | None] = {loss_node: None} + + def assign_or_accumulate_grad(forward_node, grad_value): + if forward_node in val_to_grad and forward_node.op != "placeholder": + grad_value = g.call_function( + _null_coalesce_accumulate, + (val_to_grad[forward_node], grad_value), + ) + val_to_grad[forward_node] = grad_value + + with g.inserting_before(output_node): + for node in reversed(g.nodes): + if node not in live_nodes: + continue + + def add_to_live_nodes(n): + live_nodes.setdefault(n, None) + + fx.node.map_arg(node.args, add_to_live_nodes) + fx.node.map_arg(node.kwargs, add_to_live_nodes) + if node.op == "call_module": + output_grads: tuple[fx.Node | None, ...] | fx.Node | None + if node in tuples: + stage_output = tuples[node] + output_grads = tuple(val_to_grad.get(n) for n in tuples[node]) + outputs_with_grads_idxs = [ + i for i, n in enumerate(tuples[node]) if n in live_nodes + ] + else: + stage_output = (node,) + output_grads = val_to_grad[node] + outputs_with_grads_idxs = [0] + + output_grads = ( + (output_grads,) + if not isinstance(output_grads, tuple) + else output_grads + ) + + grad_call = g.call_function( + stage_backward, + kwargs={ + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": list(node.all_input_nodes), + "outputs_with_grads_idxs": outputs_with_grads_idxs, + }, + ) + # Insert backward stage debug info + kwargs_copy = dict(grad_call.kwargs) + grad_call.kwargs = kwargs_copy + + grad_call_proxy = fx.Proxy(grad_call) + grads = grad_call_proxy.node + + input_nodes = list(node.all_input_nodes) + grads_proxy = fx.Proxy(grads) + for i, input_node in enumerate(input_nodes): + assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index] + + return g + + +class PipeSequential(torch.nn.Sequential): + @staticmethod + def from_sequential(sequential_instance: torch.nn.Sequential): + return PipeSequential(*[copy.copy(m) for m in sequential_instance]) + + def forward(self, input): + for i, module in enumerate(self): + input = module(input) + if i != len(self) - 1: + pipe_split() + return input + + +class LossWrapper(torch.nn.Module): + """ + LossWrapper is a convenient abstract class that allows you to wrap up both + your model as well as its loss function and specify the connectivity between + the inputs, model, loss function, and output value. Example:: + + class MyModelWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + loss_value = self.loss_fn(model_out, targets) + return loss_value + + The above example defines a connectivity where we expect the forward/loss/backward + training procedure to take two arguments (x and targets), pass x into the module + to get the output of the feedforward computation, pass the model output and the + targets value into the loss function, and get and return the loss value, which will + be backpropagated by PiPPy. The above class would then be instantiated like:: + + model = ... # instantiate the model + loss_fn = torch.nn.MSELoss() # for the sake of demonstration + + wrapper = MyModelWrapper(model, loss_fn) + pipe = Pipe.from_tracing(wrapper, ...) + + """ + + def __init__(self, module, loss_fn): + super().__init__() + self.module = module + self.loss_fn = loss_fn + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This instance of LossWrapper does not have an overridden" + "forward(). Please implement forward() to specify the arguments, " + "connection between the module and loss, and loss output " + "value." + ) + + +class TrivialLossWrapper(LossWrapper): + # pyrefly: ignore [bad-override] + def forward(self, x, targets): + model_out = self.module(x) + return self.loss_fn(model_out, targets) + + loss_spec = True + + +# Pipe model representation +# +# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies +# a single topological ordering of pipeline "stages" that, when run in series, +# constitutes all of the operations of the program. However, unlike `nn.Sequential`, +# Pipe allows non-local usages of values, so long as those uses still respect +# topological ordering. In particular: +# +# 1. Non-local activations. This type of usage can appear in, for example, skip +# connections. These values will be directly transmitted from the "def" stage +# to all stages that use them skipping intermediate stages. During autograd, +# gradients will be propagated back through this skip connection reverse +# to how activations propagated in the forward pass. +# 2. Non-local parameter/module invocations. This occurs when a parameter is used +# in a stage downstream of where it is resident. These values can be carried +# forward similarly to (1), but in addition one might want to replicate the +# value on multiple stages. Gradients for these shared parameters will be +# accumulated separately on each stage, but there will be an additional +# gradient accumulation before the optimizer step. + + +# Register `_pipe_split()` as an ATen operator. This is required for Export to +# preserve this marker in the graph. +torch.library.define("pippy::_pipe_split", "() -> ()") + + +@torch.library.impl("pippy::_pipe_split", "BackendSelect") +def _pipe_split(): + return None + + +@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] +def _pipe_split(): # noqa: F811 + return None + + +# Add an alias for convenience +aten_pipe_split_alias = torch.ops.pippy._pipe_split.default + +# Ask Export to preserve the `_pipe_split` op. +# See examples in pytorch/torch/fx/node.py +fx.node._side_effectful_functions.add(aten_pipe_split_alias) + + +# User facing API +def pipe_split(): + """ + pipe_split is a special operator that is used to mark the boundary between + stages in a module. It is used to split the module into stages. It is a + no-op if your annotated module is run eagerly. + + Example: + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = torch.mm(x, self.mm_param) + >>> x = torch.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x + + The above example will be split into two stages. + """ + return torch.ops.pippy._pipe_split() + + +class MultiUseParameterConfig(Enum): + TRANSMIT = 1 + REPLICATE = 2 + + +MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]] + + +class DetachExecutor(fx.Interpreter): + """ + Special interpreter to run the split_gm in testing that detaches all inputs to + a module invocation. This is needed so that the values at the boundary are + leaf modules in autograd execution. + """ + + def __init__(self, module, garbage_collect_values=True): + garbage_collect_values = False + super().__init__(module, garbage_collect_values) + self.value_remap = {} + + def run(self, *args, initial_env=None): # type: ignore[override] + self.value_remap = {} + return super().run(*args, initial_env=initial_env) + + def call_module(self, target, args, kwargs): + def detach_tensors(a): + if isinstance(a, torch.Tensor) and a.requires_grad: + if a not in self.value_remap: + new_val = a.detach().requires_grad_(True) + self.value_remap[a] = new_val + return self.value_remap[a] + else: + return a + + """ + def dont_traverse_size(a): + return type(a) is not torch.Size + """ + + args = map_aggregate( + args, + detach_tensors, # dont_traverse_size + ) + kwargs = map_aggregate( + kwargs, + detach_tensors, # dont_traverse_size + ) + + return super().call_module(target, args, kwargs) + + def call_function(self, target, args, kwargs): + # HACK to reroute saved input tensors to point to the detach()ed version + if target is stage_backward: + kwargs = dict(kwargs) + kwargs["input_values"] = [ + self.value_remap.get(v, v) for v in kwargs["input_values"] + ] + return super().call_function(target, args, kwargs) + + +class _NodeReference: + def __init__(self, name): + self.name = name + + name: str + + +class _LinearNodeList: + def __init__(self, node_list): + self.serialize_node_list = [] + for node in node_list: + node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + serialize_node = fx.Node( + graph=None, # type: ignore[arg-type] + name=node.name, + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + return_type=node.type, + ) + serialize_node.meta = copy.copy(node.meta) + self.serialize_node_list.append(serialize_node) + + def to_graph(self): + graph = fx.Graph() + + ref_str_to_node: dict[str, fx.Node] = {} + + def ref_to_node(arg): + if isinstance(arg, _NodeReference): + return ref_str_to_node[arg.name] + else: + return arg + + for node in self.serialize_node_list: + node_args = map_aggregate(node.args, ref_to_node) + node_kwargs = map_aggregate(node.kwargs, ref_to_node) + deser_node = graph.create_node( + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + name=node.name, + type_expr=node.type, + ) + ref_str_to_node[node.name] = deser_node + + return graph + + +def _direct_serialization_deserialize(body, nodes): + """ + Custom `__reduce__` method for serialization. + DO AS I SAY -- NOT AS I DO. This violates the principle that + GraphModules serialize via code export & re-tracing. We allow + for this here because **PIPE STAGES SHOULD NOT BE PERSISTED + TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting + these instances to disk will expose internal implementation + details of `fx.Graph` and related data structures and is + NOT advised. + """ + + class DummyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__.update(body) + + dummy = DummyModule(body) + + return fx.GraphModule(dummy, nodes.to_graph()) + + +def _direct_serialization_reduce(self): + serialization_dict = dict(self.__dict__) + serialization_dict.pop("_graph") + return ( + _direct_serialization_deserialize, + (serialization_dict, _LinearNodeList(self.graph.nodes)), + ) + + +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type] + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + +class Pipe(torch.nn.Module): + def __init__( + self, + split_gm: fx.GraphModule, + num_stages: int, + has_loss_and_backward: bool, + loss_spec, + ): + # TODO: is there a way not to hard wire init? + torch.nn.Module.__init__(self) + self.split_gm: fx.GraphModule = split_gm + self.executor: DetachExecutor = DetachExecutor(self.split_gm) + self.num_stages: int = num_stages + self.has_loss_and_backward = has_loss_and_backward + self.loss_spec = loss_spec + + for node in split_gm.graph.nodes: + assert ( + node.op in {"call_module", "placeholder", "output"} + or (node.op, node.target) == ("call_function", operator.getitem) + or (node.op, node.target) == ("call_method", "backward") + or (node.op, node.target) == ("call_function", stage_backward) + or (node.op, node.target) + == ("call_function", _null_coalesce_accumulate) + ), node + + # Detect replicated parameters so we know that we have to do an additional allreduce + # before applying the optimizer + # + # Note that this also handles the case where there were multiple calls to a single + # module from different stages, regardless of whether that module invocation + # was handled by the logic above. + + # Map parameter value to a dictionary that maps the user pipeline module + # to the local qualname within that module + params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {} + + for m_qualname, mod in self.split_gm.named_children(): + for p_qualname, param in mod.named_parameters(): + params_to_users.setdefault(param, {}) + params_to_users[param][m_qualname] = p_qualname + + self.replicated_params: list[dict[str, str]] = [ + use_mapping + for _, use_mapping in params_to_users.items() + if len(use_mapping) > 1 + ] + + # We must break the aliasing relationship between the replicated parameters for correct + # numerics in reference runs. If we do not do this, the autograd tape in separate stages + # will have a reference to the same tensor value and will erroneously apply gradient + # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the + # values so that we have separate instances. + for param_mapping in self.replicated_params: + for submod_name, param_qualname in param_mapping.items(): + submod = getattr(self.split_gm, submod_name) + atoms = param_qualname.split(".") + for atom in atoms[:-1]: + submod = getattr(submod, atom) + setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1]))) + + def throw(self, *args, **kwargs): + raise RuntimeError( + "To run pipeline locally, invoke the Pipe object directly, not `split_gm`" + ) + + self.split_gm.forward = throw + + # Make submodules use custom direct-serialized GraphModule + i = 0 + while True: + try: + name = get_submod_name(i) + submod = getattr(self.split_gm, name) + submod.__class__.__reduce__ = _direct_serialization_reduce + i += 1 + except AttributeError: + break + + def forward(self, *args, **kwargs): + executor_args = args + if len(kwargs) > 0: + parameters = [] + for node in self.split_gm.graph.nodes: + if node.op == "placeholder": + if node.args and len(node.args) > 0: + parameters.append( + Parameter( + node.target, + Parameter.POSITIONAL_OR_KEYWORD, + default=node.args[0], + ) + ) + else: + parameter_kind = Parameter.POSITIONAL_OR_KEYWORD + param_name = node.target + if node.target.startswith("**"): + parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment] + param_name = param_name[2:] + elif node.target.startswith("*"): + parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment] + param_name = param_name[1:] + parameters.append(Parameter(param_name, parameter_kind)) + signature = Signature(parameters) + ba = signature.bind(*args, **kwargs) + ba.apply_defaults() + executor_args = ba.arguments.values() # type: ignore[assignment] + + res = self.executor.run(*executor_args) + + return res + + def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ + if stage_idx < 0 or stage_idx >= self.num_stages: + raise ValueError(f"Invalid stage index {stage_idx}!") + + submod_name = get_submod_name(stage_idx) + return getattr(self.split_gm, submod_name) + + @staticmethod + def _number_and_count_forward_stages(gm: fx.GraphModule): + num_stages = 0 + found_idxs: dict[int, None] = {} + for node in gm.graph.nodes: + if node.op == "call_module" and node.target.startswith(PP_SUBMOD_PREFIX): + node.meta["stage_idx"] = int(node.target[len(PP_SUBMOD_PREFIX) + 1 :]) + found_idxs.setdefault(node.meta["stage_idx"]) + num_stages += 1 + + # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule + # Update: the following assert may fail against some torch versions >= + # 2.2.0, as: + # submod_0, submod_1, submod_2, ... + # may be named as + # submod_0, submod_2, submod_4, ... + # TODO: investigate + # assert all(i in found_idxs for i in range(num_stages)) + + return num_stages + + @staticmethod + def _from_traced( + mod: torch.nn.Module, + exported_program: ExportedProgram, + multi_use_param_spec: MultiUseParamSpec | None = None, + output_loss_value_spec=None, + split_policy: Callable[[torch.fx.GraphModule], torch.fx.GraphModule] + | None = None, + ): + """ + Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate + which value in the output of `forward` is the loss value on which PiPPy should apply + backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``, + you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns + a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify + ``output_loss_value_spec={'loss': True, 'model_out': False}`` + """ + + traced = exported_program.module(check_guards=False) + + if split_policy is not None: + logger.info("Auto-splitting model") + traced = split_policy(traced) # type: ignore[arg-type] + + logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator] + + # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving + # parameters relies on the invariant that parameter accesses happen once. This is not necessarily + # the case (especially with custom tracers), so fix that up here. + get_attr_nodes: dict[str, fx.Node] = {} + for node in traced.graph.nodes: # type: ignore[union-attr] + if node.op == "get_attr": + get_attr_nodes.setdefault(node.target, node) + + if get_attr_nodes[node.target] != node: + node.replace_all_uses_with(get_attr_nodes[node.target]) + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + # avoid looking at next node by keeping track of previous pipe_split + prev_pipe_split_idx = -1 + pipe_split_nodes_to_erase = set() + for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr] + if (node.op, node.target) == ("call_function", pipe_split): + if prev_pipe_split_idx == i - 1: + pipe_split_nodes_to_erase.add(node) + prev_pipe_split_idx = i + + for node in pipe_split_nodes_to_erase: + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + traced.recompile() # type: ignore[operator] + + part_idx = 0 + + def split_callback(n: fx.Node): + nonlocal part_idx + if (n.op, n.target) == ( + "call_function", + aten_pipe_split_alias, + ): + logger.debug(f"Found pipe_split {part_idx}") # noqa: G004 + part_idx += 1 + return part_idx + + # TODO: what does split do with module invocations? does it move the modules + # into the submodules? + split = split_module(traced, mod, split_callback, partition_affix="pp") # type: ignore[arg-type] + # a (custom) tracer can produce dead code like orphan get_attr nodes + split.graph.eliminate_dead_code() + + # peephole to remove pipe_split + for submodule in split.modules(): + if isinstance(submodule, fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ( + "call_function", + aten_pipe_split_alias, + ): + submodule.graph.erase_node(node) + submodule.recompile() + + for name, submodule in split.named_children(): + if isinstance(submodule, fx.GraphModule): + new_submod = _outline_submodules(submodule.graph) + # Replace old submod + split.register_module(name, new_submod) + + # TODO: backport this into split_module + def delete_user_reference(node, user): + """ + Delete reference of `node` from `user`'s arg list. + Args: + - node: a `get_attr` node at root. + - user: a submodule node that uses `node`. + """ + assert len(user.kwargs) == 0 + use_idxs = [i for i, arg in enumerate(user.args) if arg == node] + assert len(use_idxs) == 1 + args_copy = list(user.args) + args_copy.pop(use_idxs[0]) + user.args = tuple(args_copy) + logger.debug( + f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004 + ) + + # A list of param referrals for deferred deletion. + # To be accumulated in `move_param_to_callee`. + to_delete = [] + + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + + def move_param_to_callee( + root, + callee_name, + param_fqn, + ): + """ + Move a parameter from the root module to a submodule. + Args: + root: The root module. + callee_name: The name of the submodule to move the parameter to. + param_fqn: The fully qualified name of the parameter to move. + """ + # `atoms` is a list of strings representing the path to the + # parameter in the original model + atoms = param_fqn.split(".") + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) + # Check whether the parameter is a buffer or a parameter + is_buffer = atoms[-1] in mod_itr._buffers + + # Check whether the parameter is a tensor + assert isinstance(param_val, torch.Tensor), ( + f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}." + + ( + f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" + f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect " + f"usages of '{param_fqn}' in the traced graph." + if isinstance(param_val, torch.nn.Module) + else "" + ) + ) + + # Get submodule + callee = root.get_submodule(callee_name) + assert not hasattr(callee, param_fqn), ( + f"Module {callee_name} already has a parameter named {param_fqn}" + ) + + # Assign the parameter to the submodule + if is_buffer: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.BUFFER, + persistent=True, # TODO: handle non-persistent buffer + ) + else: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.PARAMETER, + ) + logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004 + + # Next step is to replace placeholder of submodule with a get_attr. + # Those placeholders are created by `split_module` inside each + # submodule. + # Update: this step is now moved to `_sink_params` because + # `_sink_params` can do it recursively (i.e. for modules inside + # submodule) + + to_delete.append((mod_itr, atoms[-1])) + + # Get the list of all parameters in the root module + attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) + for node in attr_nodes: + # Check whether the parameter is used in only one submodule + if len(node.users) > 1: + logger.info( + f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004 + ) + for user in node.users: + assert user.op == "call_module" + # Move parameter into submodule + move_param_to_callee( + split, + user.target, + node.target, + ) + + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: dict[int, set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: dict[str, list[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: dict[str, list[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + + # Deferral deletion: Remove the original attributes (to params) from the + # root GraphModule + for mod_itr, last_atom in to_delete: + try: + delattr(mod_itr, last_atom) + except AttributeError: + # This is expected if the parameter is used in multiple stages + pass + + # This is done by (1) `_sink_params` at each submodule; + for submod in split.children(): + if isinstance(submod, fx.GraphModule): + _sink_params(submod, inputs_to_state, []) + submod.graph.lint() + submod.recompile() + + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + unused_attributes.discard(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + + for node in attr_nodes: + # And (2): remove `get_attr` node from submod's arg list + for user in copy.copy(node.users): + assert user.op == "call_module" + delete_user_reference(node, user) + # And (3): remove the `get_attr` node from the root graph. + split.graph.erase_node(node) + + split.delete_all_unused_submodules() + split.graph.lint() + split.recompile() + + num_stages = Pipe._number_and_count_forward_stages(split) + + has_loss_and_backward = False + generated_loss_spec = output_loss_value_spec + + if output_loss_value_spec is not None: + loss_node, output_node, generated_loss_spec = _find_loss_output( + mod, split.graph, output_loss_value_spec + ) + if loss_node is not None: + _insert_stage_symbolic_backward( + split.graph, + loss_node, + output_node, + ) + split.recompile() + has_loss_and_backward = True + logger.debug("Pipeline is in training mode, backward pass generated") + else: + raise RuntimeError( + f"Did not find any loss value according to {output_loss_value_spec=}" + ) + else: + logger.debug("Pipeline is in inference mode, backward pass not generated") + + logger.debug(f"Full pipe model:\n{split}") # noqa: G004 + + return Pipe( + split, + num_stages, + has_loss_and_backward, + generated_loss_spec, + ) + + def print_readable(self): + """ + Print the pipe in a human-readable format. + This will print both the root pipe and each stage module. + """ + self.split_gm.print_readable() + + @staticmethod + def _trace_with_export( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: dict[str, Any] | None = None, + ) -> ExportedProgram: + logger.info("Tracing model ...") + try: + ep = torch.export.export(mod, example_args, example_kwargs) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pytorch.org/docs/stable/export.html" + ) from e + + return ep + + @staticmethod + def from_tracing( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: dict[str, Any] | None = None, + split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None, + ): + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + # Deprecated + """ + if output_chunk_spec is not None: + output_loss_value_spec = map_aggregate( + output_chunk_spec, lambda v: isinstance(v, _LossReducer) + ) + """ + + # Trace with export + exported_program = Pipe._trace_with_export( + mod, + example_args, + example_kwargs, + ) + + pipe = Pipe._from_traced( + mod, + exported_program, + multi_use_param_spec, + output_loss_value_spec=output_loss_value_spec, + split_policy=split_policy, + ) + + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + traced = exported_program.module() + submod0 = next(iter(split.children())) + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004 + f"first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr] + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr] + submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr] + ) + submod0.recompile() + + return pipe + + def __str__(self): + return self.split_gm.__str__() + + def __repr__(self): + return self.split_gm.__repr__() + + def info(self) -> PipeInfo: + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: torch.device, + group: ProcessGroup | None = None, + ) -> _PipelineStage: + """ + Create a `PipelineStage` given a stage index and distributed group. + The `PipelineStage` can run with `PipelineSchedule`s. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + if isinstance(stage_module, torch.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 + ) + + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self.info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) + + +class SplitPoint(Enum): + """ + Enum representing the points at which a split can occur in the execution of a submodule. + Attributes: + BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function. + END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function. + """ + + BEGINNING = 1 + END = 2 + + +# For backward compatibility, we kept the PipeSplitWrapper class because `class +# SplitPoint` used to be defined in this class. +class PipeSplitWrapper: + # Create a class alias for BC + SplitPoint = SplitPoint + + +def _split_before_forward(self, *args, **kwargs): + pipe_split() + return self._orig_forward(*args, **kwargs) + + +def _split_after_forward(self, *args, **kwargs): + try: + return self._orig_forward(*args, **kwargs) + finally: + pipe_split() + + +def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]): + # TODO: make this implementation out-of-place? + for qualname, split_type in spec.items(): + atoms = qualname.split(".") + predecessor_module = mod + for i, atom in enumerate(atoms[:-1]): + try: + predecessor_module = getattr(predecessor_module, atom) + except AttributeError as e: + raise AttributeError( + f"Specified target {qualname} referenced " + f"nonexistent module {'.'.join(atoms[: i + 1])}" + ) from e + + mod_to_wrap = getattr(predecessor_module, atoms[-1]) + mod_to_wrap._orig_forward = mod_to_wrap.forward + if split_type == SplitPoint.BEGINNING: + mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap) + elif split_type == SplitPoint.END: + mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap) + else: + raise ValueError("Unknown split point type.") + + +def pipeline( + module: torch.nn.Module, + mb_args: tuple[Any, ...], + mb_kwargs: dict[str, Any] | None = None, + split_spec: dict[str, SplitPoint] | None = None, + split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None, +) -> Pipe: + """ + Split a module based on a specification. + + See `Pipe` for more details. + + Arguments + --------- + module: + The module to be split. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) + split_spec: + A dictionary using submodule names as split marker. (default: `None`) + split_policy: + The policy to use for splitting the module. (default: `None`) + + Returns + ------- + A pipeline representation of class `Pipe`. + """ + if split_spec is not None and split_policy is not None: + raise ValueError( + "Cannot specify both `split_spec` and `split_policy`. Please use only one of them." + ) + + if split_spec is not None: + # Annotate split points in the module based on user spec + annotate_split_points(module, split_spec) + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + ) + else: + # Use split policy + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + split_policy=split_policy, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aacaf0b7f5e4ae7f5d221906ebb5b1b6ff93dea9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._IR import Pipe, pipe_split, pipeline, SplitPoint +from .schedules import ( + _ScheduleForwardOnly, + Schedule1F1B, + ScheduleDualPipeV, + ScheduleGPipe, + ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, + ScheduleLoopedBFS, + ScheduleZBVZeroBubble, +) +from .stage import build_stage, PipelineStage + + +__all__ = [ + "Pipe", + "pipe_split", + "SplitPoint", + "pipeline", + "PipelineStage", + "build_stage", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", + "ScheduleDualPipeV", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..102d350f7070def4f60183049db1ff37b59061f9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d17ce85bbb3acd4ffa2725c98190b27a45af86e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9589f55c0705ee41499e6598af04cebae6083173 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128d3f5ed607dfa4b0f19d534b71b81f04f163f8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..113cc5adedce41345c0487e4d99d8ccec0ba3d6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d53fe75ea5e6e64523624b5c8c9a0e41c1c0af67 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4306633b54c3a0ee6dc3c58b4a7a6718da931a46 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb76dd204886393aaf19b257dc95721df3dc155 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c24e972caea9ef880ee277087c76395b8ce436 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcf294c2946c5107c7c506b9846cd320155b27c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_backward.py @@ -0,0 +1,418 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import collections +import logging +from collections.abc import Iterator +from typing import Any + +import torch +from torch.autograd.graph import GradientEdge, Node +from torch.nn import Parameter + +from ._debug import map_debug_info + + +logger = logging.getLogger(__name__) + + +def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None: + """ + Get the grad function or grad accumulator for a tensor. + + Accumulate grad nodes are lazily created, so we need to a + dummy view in order to trigger its creation. + """ + if t.requires_grad and t.grad_fn is None: + # if no grad function (leaf tensors) we use view + viewed_t = t.view_as(t) + grad_fn = viewed_t.grad_fn + if grad_fn is not None: + return grad_fn.next_functions[0][0] + else: + raise RuntimeError( + "Attempted to get grad_fn, but got None." + "Is this being created in a no-grad context?" + ) + else: + return t.grad_fn + + +def reverse_closure( + roots: list[Node], target_nodes: set[Node], reverse_edges_dict +) -> tuple[set[Node], set[Node]]: + """ + This function returns the reverse closure of the given roots, + i.e. the set of nodes that can be reached from the roots by following the + reverse edges of the graph. The target_nodes are the nodes that we want to + include in the closure. + """ + # Recurse until we reach a target node + closure: set[Node] = set() + visited_target_nodes = set() + q: collections.deque[Node] = collections.deque() + for node in roots: + if node is not None and node not in closure: + closure.add(node) + q.append(node) + while q: + node = q.popleft() + reverse_edges = reverse_edges_dict[node] + for fn in reverse_edges: + if fn in closure or fn is None: + continue + if fn in target_nodes: + visited_target_nodes.add(fn) + continue + closure.add(fn) + q.append(fn) + return closure, visited_target_nodes + + +def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]: + q: collections.deque[Node] = collections.deque() + root_seen: set[Node] = set() + reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list) + for node in roots: + if node is not None and node not in root_seen: + q.append(node) + root_seen.add(node) + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn is not None: + if len(reverse_edges_dict[fn]) == 0: + q.append(fn) + reverse_edges_dict[fn].append(node) + return reverse_edges_dict + + +def get_param_groups( + inputs: list[Node], params: list[Node], reverse_edges_dict +) -> list[dict[str, Any]]: + """ + Given a list of inputs and a list of parameters, return a list of parameter + groups, where each group contains the parameters and the intermediates that + are connected to the parameters. + + The returned list of parameter groups is a list of dictionaries, where each + dictionary contains the following keys: + - "params": a set of parameters + - "intermediates": a set of intermediates + + The returned list of parameter groups is a list of dictionaries, + """ + # reverse graph that starts with inputs, and goes up to the dOutput or the loss, + # but omits weights and any subgraphs connecting weights to this closure + inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) + param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates + for param in params: + closure, intersected = reverse_closure( + [param], inputs_closure, reverse_edges_dict + ) + param_group: dict[str, set] = { + "params": {param}, + "intermediates": intersected, + } + for input_node in intersected: + existing = param_groups.get(input_node) + if existing is not None: + existing["params"] = existing["params"].union(param_group["params"]) + existing["intermediates"] = existing["intermediates"].union( + param_group["intermediates"] + ) + param_group = existing + else: + param_groups[input_node] = param_group + + # Sanity check: union of all param_groups params should be equal to all params + union_params: set[Node] = set() + seen_ids: set[int] = set() + unique_param_groups = [] + for param_group in param_groups.values(): + if id(param_group) not in seen_ids: + seen_ids.add(id(param_group)) + unique_param_groups.append(param_group) + union_params = union_params.union(param_group["params"]) + + # The assert will only be true if the input tensor requires gradients, + # otherwise the autograd graph will miss the first layer of inputs + # assert union_params == set(params) + return unique_param_groups + + +def stage_backward_input( + stage_outputs_or_loss: list[torch.Tensor], + output_grads: list[torch.Tensor] | None, + input_values: list[torch.Tensor], + weights: Iterator[Parameter], +) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]: + """ + Compute the gradients for only the stage inputs with + respect to the stage outputs (if non-last stage) or loss (if last stage) + + After computing input gradients, we save the intermediate nodes in `param_groups` + for later use in stage_backward_weight. We don't need to save any other intermediate nodes + that aren't needed for dW because when we do dW calculation, we start from saved intermediates. + Detaching the stage_outputs_or_loss at the end of this function is important as + it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). + """ + stage_output_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) + ) + stage_input_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) + ) + weight_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, weights)) + ) + + reverse_edges_dict = construct_reverse_graph(stage_output_grad_fns) + param_groups = get_param_groups( + stage_input_grad_fns, weight_grad_fns, reverse_edges_dict + ) + + handles = [] + for param_group in param_groups: + for i, intermediate in enumerate(param_group["intermediates"]): + + def get_hook(param_group, i): + def hook(grad_inputs): + if param_group.get("grads", None) is None: + param_group["grads"] = [None] * len( + param_group["intermediates"] + ) + param_group["grads"][i] = grad_inputs + + return hook + + # These are always "split" nodes that we need to recompute, so + # save their inputs. + handle = intermediate.register_prehook(get_hook(param_group, i)) + handles.append(handle) + + if output_grads is None: + # In case this is the loss and there are no output_grads, then we just use 1s + output_grads = [ + torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss + ] + + # Some inputs may not be used or may not require gradients, so we filter them out + input_values = [inp for inp in input_values if inp.requires_grad] + dinputs = torch.autograd.grad( + stage_outputs_or_loss, + inputs=input_values, + grad_outputs=output_grads, + retain_graph=True, + ) + # Update the gradients for inputs + for inp, dinput in zip(input_values, dinputs): + if inp.grad is None: + inp.grad = dinput + else: + inp.grad += dinput + + # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph + # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory + for t in stage_outputs_or_loss: + t.detach_() + + # hooks are no longer necessary, clean up for consistency + for handle in handles: + handle.remove() + + return dinputs, param_groups + + +def stage_backward_weight( + weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False +) -> tuple[torch.Tensor | None, ...]: + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads: list[torch.Tensor | None] = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + + for param_group in param_groups: + valid_edges = [] + valid_grad_outputs: list[torch.Tensor] = [] + + for grads_tuple, intermediate in zip( + param_group["grads"], param_group["intermediates"] + ): + non_none_grads = [g for g in grads_tuple if g is not None] + if non_none_grads: + summed_grad = sum(non_none_grads) + valid_edges.append(GradientEdge(intermediate, 0)) + # pyrefly: ignore [bad-argument-type] + valid_grad_outputs.append(summed_grad) + + # Break a reference cycle caused inside stage_backward_input->get_hook->hook + # The summarized cycle is: + # `hook` -> cell -> param_group -> intermediates -> `hook` + # because we install the hook function onto each of the intermediate autograd nodes. + # We need to keep intermediates alive up until backward_weight, but we can free it now. + del param_group["intermediates"] + + if valid_edges: # Only call autograd.grad if we have valid gradients + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) + dweights = torch.autograd.grad( + valid_edges, + weights_edges, + grad_outputs=valid_grad_outputs, + retain_graph=retain_graph, + ) + + # release grad memory early after use + del param_group["grads"] + + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw + # return grads in the original order weights were provided in + return tuple(weight_grads) + + +def stage_backward( + stage_output, + output_grads, + input_values, + outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used +) -> tuple[torch.Tensor | None, ...]: + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + """ + if outputs_with_grads_idxs is not None: + # Deprecated, not used in runtime calls, only exists in compiler + stage_output = [stage_output[i] for i in outputs_with_grads_idxs] + output_grads = [output_grads[i] for i in outputs_with_grads_idxs] + + try: + # stage_output may be a composite datatype like dict. Extract all individual + # tensor values here + stage_output_tensors: list[torch.Tensor] = [] + output_grad_tensors: list[torch.Tensor | None] = [] + + def extract_tensors_with_grads( + output_val, + grad_val, + # Don't delete me- see [Note: ref cycle] + extract_tensors_with_grads, + ): + if isinstance(output_val, torch.Tensor): + if not output_val.requires_grad and output_val.grad_fn is None: + return + assert isinstance(grad_val, (torch.Tensor, type(None))), ( + f"Expected Tensor or None gradient but got {type(grad_val)}" + ) + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance(grad_val, (tuple, list)), ( + f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + ) + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads( + ov, + gv, + extract_tensors_with_grads, + ) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val: + extract_tensors_with_grads( + output_val[k], grad_val[k], extract_tensors_with_grads + ) + else: + # Output is a non-tensor type; just ignore it + pass + + # Note: ref cycle + # break a ref cycle that would keep tensors alive until GC runs + # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward + # and used in extract_tensors_with_grads + # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors, + # and to itself (extract_tensors_with_grads) since it makes a recursive call + # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad + # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore + extract_tensors_with_grads( + stage_output, output_grads, extract_tensors_with_grads + ) + + torch.autograd.backward( + stage_output_tensors, + grad_tensors=output_grad_tensors, # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs: list[torch.Tensor | None] = [] + for val in input_values: + if isinstance(val, torch.Tensor): + grad_inputs.append(val.grad) + # Since gradients that will pass back to previous stages do not require gradient accumulation, + # by decrementing the gradients' reference count at this point, the memory of gradients will be + # returned to the allocator as soon as the next micro batch's get_bwd_send_ops comes and current + # asynchronous send completes. + # This prevents the gradients from persisting in GPU memory for the entire duration of step_microbatches + # until clear_runtime_states() is called. + val.grad = None + else: + grad_inputs.append(None) + + # Alternative impl: `torch.autograd.grad`. + # Note that `torch.autograd.grad` will not accumulate gradients into the + # model's parameters. + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, torch.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = torch.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return tuple(grad_inputs) + + +# TODO: handling requires_grad=False dynamically. Can we analyze this during initial +# IR emission? +def _null_coalesce_accumulate(lhs, rhs): + """ + Coalesce two values, even if one of them is null, returning the non-null + value. + """ + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return torch.add(lhs, rhs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..a3201d2d3adf1d05921e070d14b4e544844df88f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_debug.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +from torch.fx.node import Argument + + +def friendly_debug_info(v: object) -> Argument: + """ + Helper function to print out debug info in a friendly way. + """ + if isinstance(v, torch.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})" + else: + return str(v) + + +def map_debug_info(a: Argument) -> Argument: + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. + """ + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecc5bf19ab17d83e0f3128290c0d5cc4b862b4d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_schedule_visualizer.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +""" +This visualizer requires matplotlib to be installed. + +Example usage: + +ops = get_schedule_ops("InterleavedZeroBubble", 4, 8) +visualize_schedule(ops, "test.png") +""" + +import collections +from typing import NamedTuple +from unittest import mock + +from torch.distributed.pipelining.schedules import ( + _Action, + _ComputationType, + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, + PipelineScheduleMulti, + PipelineScheduleSingle, +) +from torch.distributed.pipelining.stage import PipelineStage + + +class OpKey(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: int + + +def get_schedule_ops( + schedule: str | type[_PipelineSchedule], + pp_degree: int, + num_microbatches: int, + num_stages_per_rank: int | None = None, + add_spacing: bool = False, + with_comms: bool = False, +) -> list[list[_Action | None]]: + """ + Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists + where each inner list represents a rank and each element in the inner list represents an action. + + The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance. + """ + if add_spacing and with_comms: + raise ValueError("Cannot add spacing and view comms at the same time") + + if isinstance(schedule, str): + schedule_class = get_schedule_class(schedule) + elif issubclass(schedule, _PipelineSchedule): + schedule_class = schedule + else: + raise ValueError(f"Invalid schedule: {schedule}") + + # Create a mock of the PipelineStage class + mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True) + # Set the return values for group_rank and group_size methods + mock_pipeline_stage.group_rank = 0 + mock_pipeline_stage.group_size = pp_degree + mock_pipeline_stage.submod = None + + # Check num_stages_per_rank is valid + if issubclass(schedule_class, PipelineScheduleSingle): + if num_stages_per_rank is None: + num_stages_per_rank = 1 + assert num_stages_per_rank == 1 + stages = mock_pipeline_stage + stages.num_stages = num_stages_per_rank * pp_degree + elif issubclass(schedule_class, PipelineScheduleMulti): + if num_stages_per_rank is None: + num_stages_per_rank = 2 + assert num_stages_per_rank >= 2 + stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)] + for stage in stages: + stage.num_stages = num_stages_per_rank * pp_degree + + else: + raise ValueError(f"Invalid schedule: {schedule_class}") + + # Instantiate the schedule class + # pyrefly: ignore [bad-instantiation, bad-argument-type] + schedule_instance = schedule_class(stages, num_microbatches) + assert schedule_instance.pipeline_order is not None + + # Convert to List[List[_Action]] + all_actions: list[list[_Action | None]] = [] + if with_comms: + runtime = _PipelineScheduleRuntime(stages, num_microbatches) + runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order) + for rank in range(pp_degree): + all_actions.append(list(runtime.pipeline_order_with_comms[rank])) + else: + for rank in range(pp_degree): + all_actions.append(schedule_instance.pipeline_order[rank]) + + # Add spacing + if add_spacing: + # remove all Nones, then respace + # TODO: later we can change this at the schedule creation level to not use Nones + all_actions = [ + [action for action in rank if action is not None] for rank in all_actions + ] + all_actions = add_schedule_op_spacing(all_actions) + + # Return the pipeline order + return all_actions + + +class _ComputationTypeVisual: + def __init__( + self, + color: str, + text: str = "", + width: int = 1, + ): + self.color = color + self.width = width + self.text = text + + +# Update the mapping to use _ComputationTypeVisual instances +action_type_to_color_mapping = { + _ComputationType.FORWARD: _ComputationTypeVisual("blue", "Forward"), + _ComputationType.BACKWARD_INPUT: _ComputationTypeVisual("teal", "Backward Input"), + _ComputationType.BACKWARD_WEIGHT: _ComputationTypeVisual( + "green", "Backward Weight" + ), + _ComputationType.FULL_BACKWARD: _ComputationTypeVisual( + "orange", "Full Backward", 2 + ), + _ComputationType.OVERLAP_F_B: _ComputationTypeVisual("purple", "Overlap F+B", 3), +} + + +def add_schedule_op_spacing( + schedule: list[list[_Action | None]], +) -> list[list[_Action | None]]: + """ + Add spacing to the schedule based on dependencies between ranks. + + Before adding an operation to the list, this function checks if there are + dependencies from other ranks. If there are dependencies (other ranks have + not finished processing the required microbatch), it adds None instead. + + For example, Forward microbatch 0 on rank 1 depends on rank 0 processing + Forward microbatch 0 first. + + Args: + schedule: The original schedule as a list of lists where each inner list + represents a rank and each element represents an action. + + Returns: + A new schedule with proper spacing based on dependencies. + """ + if not schedule: + return schedule + + num_stages = ( + max( + action.stage_index + for rank_actions in schedule + for action in rank_actions + if action is not None + ) + + 1 + ) + + num_ranks = len(schedule) + spaced_schedule: list[list[_Action | None]] = [[] for _ in range(num_ranks)] + rank_ops = [collections.deque(ops) for ops in schedule] + + # Track completion times: (stage_index, action_type, microbatch_index) -> completion_time + scheduled_ops: dict[OpKey, int] = {} + + def is_dependency_ready(dependency_key: OpKey, timestep: int) -> bool: + """Check if a dependency operation has completed by the given timestep.""" + return ( + dependency_key in scheduled_ops + and timestep >= scheduled_ops[dependency_key] + ) + + def get_dependencies(action: _Action) -> list[OpKey]: + """Get the list of dependencies for an action.""" + stage_idx = action.stage_index + comp_type = action.computation_type + mb_idx = action.microbatch_index + + # Ensure mb_idx is not None for dependency tracking + assert mb_idx is not None, f"Action {action} has None microbatch_index" + + # First stage forward has no dependencies + if stage_idx == 0 and comp_type == _ComputationType.FORWARD: + return [] + + # Last stage backward depends on forward from previous stage + if stage_idx == num_stages - 1 and comp_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + ): + return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)] + + # Forward depends on previous stage forward + if comp_type == _ComputationType.FORWARD: + return [OpKey(stage_idx - 1, _ComputationType.FORWARD, mb_idx)] + + # Backward depends on next stage backward + if comp_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + ): + return [ + OpKey(stage_idx + 1, _ComputationType.FULL_BACKWARD, mb_idx), + OpKey(stage_idx + 1, _ComputationType.BACKWARD_INPUT, mb_idx), + ] + + # Weight backward depends on input backward + if comp_type == _ComputationType.BACKWARD_WEIGHT: + return [OpKey(stage_idx, _ComputationType.BACKWARD_INPUT, mb_idx)] + + raise RuntimeError(f"Unknown computation type: {comp_type}") + + def is_action_ready(action: _Action, timestep: int) -> bool: + """Check if an action is ready to be scheduled at the given timestep.""" + # For OR dependencies (like backward), check if any dependency is satisfied + if action.computation_type in ( + _ComputationType.FULL_BACKWARD, + _ComputationType.BACKWARD_INPUT, + _ComputationType.BACKWARD_WEIGHT, + ): + dependencies = get_dependencies(action) + return any(is_dependency_ready(dep, timestep) for dep in dependencies) + # For AND dependencies, all must be satisfied + elif action.computation_type == _ComputationType.FORWARD: + dependencies = get_dependencies(action) + return all(is_dependency_ready(dep, timestep) for dep in dependencies) + elif action.computation_type == _ComputationType.OVERLAP_F_B: + assert action.sub_actions is not None, ( + f"OVERLAP_F_B action {action} has None sub_actions" + ) + dep_list: list[bool] = [] + for sub_action in action.sub_actions: + dep_list.append(is_action_ready(sub_action, timestep)) + return all(dep_list) + else: + raise RuntimeError(f"Unknown computation type: {action.computation_type}") + + def schedule_action(action: _Action, rank: int, timestep: int) -> int: + """Schedule an action and return completion time.""" + spaced_schedule[rank].append(action) + comp_type = action.computation_type + comp_time = action_type_to_color_mapping[comp_type].width + completion_time = timestep + comp_time + + if comp_type == _ComputationType.OVERLAP_F_B: + # For overlap actions, schedule each sub-action with cumulative timing + assert action.sub_actions is not None, ( + f"OVERLAP_F_B action {action} has None sub_actions" + ) + cumulative_time = 0 + for sub_action in action.sub_actions: + assert sub_action.microbatch_index is not None, ( + f"Sub-action {sub_action} has None microbatch_index" + ) + sub_comp_time = action_type_to_color_mapping[ + sub_action.computation_type + ].width + cumulative_time += sub_comp_time + scheduled_ops[ + OpKey( + sub_action.stage_index, + sub_action.computation_type, + sub_action.microbatch_index, + ) + ] = timestep + cumulative_time + else: + assert action.microbatch_index is not None, ( + f"Action {action} has None microbatch_index" + ) + scheduled_ops[ + OpKey(action.stage_index, comp_type, action.microbatch_index) + ] = completion_time + + return completion_time + + # Main scheduling loop + current_timestep = 0 + timesteps_without_progress = 0 + rank_completion_times = dict.fromkeys(range(num_ranks), 0) + while rank_ops: + print(f"Current timestep: {current_timestep}") + # Process all operations during timestep until we run out of ready operations + for rank, op_queue in enumerate(rank_ops): + if not op_queue: + continue + + op_queue = rank_ops[rank] + action = op_queue[0] + print(f"Rank: {rank}, {action=}") + if action is None: + spaced_schedule[rank].append(None) + op_queue.popleft() + timesteps_without_progress = 0 + elif current_timestep >= rank_completion_times[rank] and is_action_ready( + action, current_timestep + ): + rank_completion_times[rank] = schedule_action( + action, rank, current_timestep + ) + op_queue.popleft() + timesteps_without_progress = 0 + + # Add None for ranks that are waiting + for rank in range(num_ranks): + if current_timestep >= rank_completion_times[rank]: + spaced_schedule[rank].append(None) + + # Remove empty queues and advance timestep + rank_ops = [op_queue for op_queue in rank_ops if op_queue] + current_timestep += 1 + timesteps_without_progress += 1 + + if timesteps_without_progress > max( + visual.width for visual in action_type_to_color_mapping.values() + ): + raise RuntimeError("No progress made in scheduling - possible deadlock") + + return spaced_schedule + + +def visualize_schedule( + schedule: list[list[_Action | None]], + filename: str | None = None, +) -> None: + """ + Visualize the schedule using matplotlib. + The schedule is a list of lists where each inner list represents a rank and each element in the inner list represents an action. + The actions are represented as rectangles with different colors based on their computation type. + The filename is optional and if provided, the plot will be saved to that file. + + Args: + schedule: The schedule to visualize. + filename: The filename to save the plot to. If not provided, the plot will be displayed. + add_schedule_spacing: If True, add spacing to the schedule based on dependencies between ranks. + + """ + + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + plt.rcParams["font.family"] = ( + "DejaVu Sans" # or any other font available on your system + ) + num_ranks = len(schedule) + max_actions = max(len(rank) for rank in schedule) + + # Increase the figure size to provide more space for the legend + fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2)) + max_draw_position = -1 + # Calculate dynamic font size based on figure size + font_size = min(max_actions, num_ranks) + 4 + used_computation = set() + for rank_idx, actions in enumerate(schedule): + draw_position = 0 # Initialize drawing position for each rank + for action in actions: + if action is not None: + comp_type_color = action_type_to_color_mapping.get( + action.computation_type, _ComputationTypeVisual("black") + ) + used_computation.add(action.computation_type) + color = comp_type_color.color + width = comp_type_color.width + + # Check if action has sub_actions to determine styling + if action.sub_actions is not None: + linewidth = 2 # Thicker border for compound actions + text_weight = "normal" # Bold text for compound actions + else: + linewidth = 1 # Default linewidth for regular actions + text_weight = "normal" # Default text weight + + # Draw the rectangle to represent the action duration + rect = Rectangle( + (draw_position, num_ranks - rank_idx - 1), + width, + 1, + facecolor=color, + edgecolor="black", + linewidth=linewidth, + ) + ax.add_patch(rect) + + # Draw the text centered within the rectangle + ax.text( + draw_position + width / 2, + num_ranks - rank_idx - 1 + 0.5, + str(action), + ha="center", + va="center", + fontsize=font_size, + color="white", + weight=text_weight, + ) + + draw_position += width + else: + draw_position += 1 # Move to the next + max_draw_position = max(max_draw_position, draw_position) + ax.set_xlim(-0.5, max_draw_position + 1) + ax.set_ylim(-0.5, num_ranks + 0.5) # Add extra space at the top + # Set y-ticks to be in the middle of each rank's row + ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)]) + ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size) + ax.set_xticklabels([]) + + # Remove grid lines and ticks + ax.grid(False) + # Add legend with larger font size + legend_elements = [ + Rectangle( + (0, 0), + 1, + 1, + facecolor=action_type_to_color_mapping[comp_type].color, + edgecolor="black", + label=action_type_to_color_mapping[comp_type].text, + ) + for comp_type in used_computation + ] + ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size) + # Save to file if filename is provided, otherwise display the plot + if filename: + plt.savefig(filename, bbox_inches="tight") + else: + plt.show() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed592f2f8d832de0703fbfa296225f17698afbf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_unflatten.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections import defaultdict + +import torch +from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry + + +def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule: + # Create an empty GraphModule to hold the outlined modules + new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + None, + [("", None, 0)], + "", + {}, + module=new_module, + ).run_outer() + new_module.graph.lint() + new_module.recompile() + return new_module diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79b74be40681425ab4a5c97198bf0a2020d1d10e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/_utils.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import logging +from dataclasses import dataclass + +import torch +from torch import fx + + +logger = logging.getLogger(__name__) + + +def flatten_args_detach(args): + """ + Flatten the args into a list form and detach the tensors from computational graph. + """ + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, torch.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + """ + Flatten the args into a list form. + """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + if not expected.stride() == given.stride(): + raise PipeliningShapeError( + f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" + ) + + +def validate_tensors_metadata( + desc, + expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], + actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +def generate_stage_to_rank_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, int]: + """ + Compute the stage id to rank mapping for either a looped or V-style schedule. + + Most commonly num_stages == pp_size * 2, but this function can be used to + compute the mapping for any number of stages per rank. + """ + mapping = {} + if style == "loop": + for stage_index in range(num_stages): + mapping[stage_index] = stage_index % pp_size + elif style == "v": + if num_stages % pp_size != 0: + raise ValueError( + f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules" + ) + + rank_index = 0 + for stage_index in range(num_stages): + mapping[stage_index] = rank_index + # dont change rank if we are on the border (to keep v shape) + if (stage_index + 1) % pp_size == 0: + continue + if (stage_index // pp_size) % 2 == 0: + rank_index += 1 + else: + rank_index -= 1 + else: + raise ValueError(f"Style {style} is not supported.") + return mapping + + +def generate_rank_to_stage_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, list[int]]: + """ + Compute the rank to stage id mapping for either a looped or V-style schedule. + + This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank. + + Returns a dictionary mapping rank -> list of stage indices assigned to that rank. + """ + stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style) + + # Invert the mapping: rank -> list of stages + rank_to_stages: dict[int, list[int]] = {} + for stage_id, rank in stage_to_rank.items(): + if rank not in rank_to_stages: + rank_to_stages[rank] = [] + rank_to_stages[rank].append(stage_id) + + # Sort the stage lists for each rank to ensure consistent ordering + for stages in rank_to_stages.values(): + stages.sort() + + return rank_to_stages + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py new file mode 100644 index 0000000000000000000000000000000000000000..a82f83072fa1897c1738e8bf879911921720cfe5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py @@ -0,0 +1,544 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from collections.abc import Sequence +from typing import Any + +import torch +from torch.fx.node import map_aggregate +from torch.nn.attention.flex_attention import BlockMask +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + + +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False + + +class _CustomReducer: + """ + Custom reducer class that can be used to specify a custom operation that + reduces losses of multiple microbatches into one value. + + Example: + >>> # xdoctest: +SKIP + >>> sum_reducer = _CustomReducer( + >>> torch.tensor(0.0), + >>> lambda a, b: a + b + >>> ) + """ + + def __init__(self, init_value, reduce_fn): + self.init_value = init_value + self.reduce_fn = reduce_fn + + +class _LossReducer(_CustomReducer): + pass + + +sum_reducer = _LossReducer(torch.tensor(0.0), operator.add) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. +DEFAULT_CHUNK_DIM = 0 + + +class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + + def __init__(self, split_dim): + self.split_dim = split_dim + + split_dim: int + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})" + ) + + def __str__(self): + return f"TensorChunkSpec({self.split_dim})" + + @staticmethod + def from_tuple( + chunk_dims: tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return kwargs_chunk_spec + + +# Class used to specify replication of inputs +class _Replicate: + pass + + +def _split_block_mask( + block_mask: BlockMask, + num_chunks: int, +) -> list[BlockMask]: + """Given a block mask, split the block mask along the batch dimension (dim0). + + Args: + block_mask: Block mask to split + num_chunks: Number of chunks to split the block mask into + + Returns: + chunk_block_masks: List of chunked block masks + """ + + # BlockMask will broadcast if B is 1. + if block_mask.kv_num_blocks.size(0) == 1: + return [block_mask] * num_chunks + + assert block_mask.kv_num_blocks.size(0) >= num_chunks, ( + "Block mask has fewer batch size than the number of chunks. " + ) + + batch_dim = 0 + kv_num_blocks_chunks = torch.tensor_split( + block_mask.kv_num_blocks, num_chunks, batch_dim + ) + kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim) + full_kv_num_blocks_chunks = ( + torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim) + if block_mask.full_kv_num_blocks is not None + else [None] * num_chunks + ) + full_kv_indices_chunks = ( + torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim) + if block_mask.full_kv_indices is not None + else [None] * num_chunks + ) + + chunk_block_masks = [] + batch_offset = 0 + for chunk_idx in range(num_chunks): + + def create_mask_mod(idx): + def batch_offset_mask_mod(b, h, q_idx, kv_idx): + b_offset = torch.full_like(b, idx) + return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx) + + return batch_offset_mask_mod + + chunk_block_masks.append( + BlockMask.from_kv_blocks( + kv_num_blocks=kv_num_blocks_chunks[chunk_idx], + kv_indices=kv_indices_chunks[chunk_idx], + full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx], + full_kv_indices=full_kv_indices_chunks[chunk_idx], + BLOCK_SIZE=block_mask.BLOCK_SIZE, + mask_mod=create_mask_mod(batch_offset), + seq_lengths=block_mask.seq_lengths, + ) + ) + batch_offset += kv_num_blocks_chunks[chunk_idx].size(0) + return chunk_block_masks + + +def _split_tensor( + tensor: torch.Tensor, + spec: TensorChunkSpec, + num_chunks: int, +) -> Sequence[torch.Tensor]: + """Given a tensor, and a chunking spec, split the tensor. + Args: + + tensor: Tensor to split + spec: Chunking spec + num_chunks: Number of chunks to split the tensor into + + Returns: + chunk_tensors: List of chunked tensors + """ + + assert tensor.size(spec.split_dim) >= num_chunks, ( + f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks" + ) + chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim) + + if not _debug_mask_minibatches: + return chunk_tensors + + expanded_chunks = [] + split_dim_idx = 0 + for chunk_tensor in chunk_tensors: + new_val = torch.zeros_like(tensor) + upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim) + + slice_indices = [slice(None, None, None)] * new_val.ndim + slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx) + new_val[slice_indices] = chunk_tensor + + expanded_chunks.append(new_val) + + split_dim_idx += chunk_tensor.size(spec.split_dim) + + return expanded_chunks + + +def _shard_dict_of_args( + args_dict, + args_chunk_spec, + num_chunks, +): + """ + Given a dictionary of args, and a dictionary of chunking specs, shard the + args according to the chunking specs. + + Args: + args_dict: Dictionary of args + args_chunk_spec: Dictionary of chunking specs + num_chunks: Number of chunks to shard the args into + + Returns: + args_split: List of sharded args + """ + + if not args_dict: + return [{} for _ in range(num_chunks)] + + assert len(args_dict) == len(args_chunk_spec), ( + f"args_dict.keys() = {list(args_dict.keys())} " + f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + ) + assert args_chunk_spec is not None # Should have been set by caller + + values, tree_spec = tree_flatten( + args_dict, is_leaf=lambda x: isinstance(x, BlockMask) + ) + chunk_specs, _ = tree_flatten( + args_chunk_spec, is_leaf=lambda x: isinstance(x, BlockMask) + ) + + # First check and find the actual number of chunks + split_sizes = [] + for v, spec in zip(values, chunk_specs, strict=True): + # The original logic is "spec is _Replicate". This doesn't seem to be + # correct. But we keep it for backward compatibility. + if spec is _Replicate or isinstance(spec, _Replicate): + split_sizes.append(num_chunks) + elif isinstance(v, torch.Tensor): + assert isinstance(spec, TensorChunkSpec) + split_sizes.append(v.size(spec.split_dim)) + elif isinstance(v, BlockMask): + assert isinstance(spec, TensorChunkSpec) + assert spec.split_dim == 0, "BlockMask only supports split_dim=0" + # BlockMask will broadcast if B is 1. + if v.kv_num_blocks.size(0) == 1: + split_sizes.append(num_chunks) + else: + split_sizes.append(v.kv_num_blocks.size(0)) + else: + raise ValueError( + f"Unsupported chunk spec: {spec} and value: {v} combination." + ) + result_num_chunks = min(*split_sizes, num_chunks) + + flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)] + for v, spec in zip(values, chunk_specs, strict=True): + v_splits: Sequence[Any] = [] + if spec is _Replicate or isinstance(spec, _Replicate): + v_splits = [v] * result_num_chunks + elif isinstance(v, torch.Tensor): + v_splits = _split_tensor(v, spec, result_num_chunks) + elif isinstance(v, BlockMask): + v_splits = _split_block_mask(v, result_num_chunks) + else: + raise ValueError( + f"Unsupported chunk spec: {spec} and value: {v} combination." + ) + + for _flat_split_result, _v_split in zip( + flat_split_results, v_splits, strict=True + ): + _flat_split_result.append(_v_split) + + return [ + tree_unflatten(_flat_split_result, tree_spec) + for _flat_split_result in flat_split_results + ] + + +def split_args_kwargs_into_chunks( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, + chunks: int, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, +) -> tuple[list[tuple], list[dict]]: + """ + Given a sequence of args and kwargs, split them into a number of chunks + according to their respective chunking specs. + + Args: + args: Tuple of args + kwargs: Dict of kwargs + chunks: Number of chunks to split the args and kwargs into + args_chunk_spec: chunking specs for args, in same shape as args + kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs + + Returns: + args_split: List of sharded args + kwargs_split: List of sharded kwargs + """ + # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that + # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` + # and `kwargs_chunk_spec` specifications. The steps are as follows: + # + # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values. + # To use a running example: suppose our inputs look like + # + # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None) + # (kwargs not shown but it's a similar process) + # + # Then for this step we would end up with + # + # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None) + # + # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2 + # + # args = ([[A, A], [B, B], [C_1, C_2]], [D, D]) + # + # 3. Rotate the nesting order such that chunks are the outer dimension + # + # args_chunks = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 4. Unflatten each chunk according to the spec + # + # args_chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + + # TODO: _debug_mask_minibatches + # Handle the case where kwargs is None + if kwargs is None: + kwargs = {} + + # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend + # their format and use default chunking along dim 0 + def default_spec(v): + if isinstance(v, torch.Tensor | BlockMask): + return TensorChunkSpec(DEFAULT_CHUNK_DIM) + else: + return _Replicate() + + if args_chunk_spec is None: + args_chunk_spec = tree_map( + default_spec, args, is_leaf=lambda v: isinstance(v, BlockMask) + ) + + if kwargs_chunk_spec is None: + kwargs_chunk_spec = tree_map( + default_spec, kwargs, is_leaf=lambda v: isinstance(v, BlockMask) + ) + + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + chunks, + ) + real_num_chunks = len(args_split_dict) + + kwargs_split = _shard_dict_of_args( + kwargs, + kwargs_chunk_spec, + real_num_chunks, + ) + + if len(kwargs_split) < real_num_chunks: + # In case kwargs are sharded into less chunks + # e.g. when `args` has no tensor, just values + real_num_chunks = len(kwargs_split) + # Re-shard args + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + real_num_chunks, + ) + + if len(args_split_dict) != len(kwargs_split): + raise RuntimeError( + "args and kwargs are split into different number of chunks: " + f"{len(args_split_dict)}, {len(kwargs_split)}" + ) + + args_split = [ + tuple(chunk_args[i] for i in range(len(chunk_args))) + for chunk_args in args_split_dict + ] + + return args_split, kwargs_split + + +def merge_chunks( + chunks: list[Any], + chunk_spec, +): + """ + Given a list of chunks, merge them into a single value according to + the chunk spec. + + Args: + chunks: list of chunks + chunk_spec: Chunking spec for the chunks + + Returns: + value: Merged value + """ + # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the + # steps are similar to the steps in that function but in reverse. Given the + # input values: + # + # chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + # args_spec = ([None, [None, TensorChunkSpec]], None) + # + # 1. Flatten the chunks according to the chunk_spec + # + # chunks_flat = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 2. Rotate the nesting order such that chunks are the inner dimension + # + # value_inner = ([A, B, [C_1, C_2]], D) + # + # 3. Concatenate sharded arguments + # + # value_combined = ([A, B, C], D) + # + # 4. Unflatten the combined args given the spec + # + # value = ([A, [B, C]], D) + + # Preliminary: flatten the chunk spec + if chunk_spec is not None: + spec_flattened, flatten_spec = tree_flatten(chunk_spec) + else: + # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields + # We obtain the output structure by flattening chunk 0 and generate the chunk_spec + chunk0_flat, flatten_spec = tree_flatten(chunks[0]) + spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat) + + # Stage 1: flatten chunks + # chunks_flattened : [num chunks, num args] + chunks_flattened = [] + + for chunk in chunks: + chunk_flattened, _ = tree_flatten(chunk) + if len(chunk_flattened) != len(spec_flattened): + raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}") + + chunks_flattened.append(chunk_flattened) + + # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and + # concatenate sharded operands + # args_flattened : [num args] + args_flattened = [] + for arg_idx, arg in enumerate(spec_flattened): + if isinstance(arg, TensorChunkSpec): + partial_values = [ + chunks_flattened[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flattened)) + ] + + if _debug_mask_minibatches: + # Infer size of individual chunks by running `tensor_split` again + overall_shape = partial_values[0].shape + for val in partial_values[1:]: + assert val.shape == overall_shape + meta_chunks = torch.tensor_split( + torch.empty(*overall_shape, device="meta"), + sections=len(partial_values), + dim=arg.split_dim, + ) + + values_to_cat = [] + chunk_start_idx = 0 + assert len(partial_values) == len(meta_chunks) + for partial_value, meta_chunk in zip( + partial_values, meta_chunks, strict=True + ): + chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) + + slice_indices = [slice(None, None, None)] * partial_value.ndim + slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx) + sliced = partial_value[slice_indices] + values_to_cat.append(sliced) + + chunk_start_idx = chunk_end_idx + + else: + values_to_cat = partial_values + + args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) + elif isinstance(arg, _CustomReducer): + reduced_val = arg.init_value + + for chunk_idx in range(len(chunks_flattened)): + reduced_val = arg.reduce_fn( + reduced_val, chunks_flattened[chunk_idx][arg_idx] + ) + + args_flattened.append(reduced_val) + else: + value = chunks_flattened[0][arg_idx] + for chunk_idx in range(1, len(chunks_flattened)): + assert chunks_flattened[chunk_idx][arg_idx] == value + args_flattened.append(value) + + # Stage 4: Unflatten combined args + return tree_unflatten(args_flattened, flatten_spec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..5657068f0bcd7008a0dc9b4a2a56e364bcc92428 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py @@ -0,0 +1,3438 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +import csv +import itertools +import logging +import re +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from collections.abc import Callable +from enum import Enum +from functools import lru_cache +from typing import Any, cast, NamedTuple, Protocol + +import torch +import torch.distributed as dist +from torch._dynamo import OptimizedModule +from torch.distributed.fsdp import FSDPModule, UnshardHandle +from torch.nn.modules.loss import _Loss +from torch.profiler import record_function + +from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec +from .stage import _PipelineStageBase + + +__all__ = [ + "get_schedule_class", + "PipelineScheduleSingle", + "PipelineScheduleMulti", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", + "ScheduleDualPipeV", +] + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + # TODO(whc) rename to _ActType? + FORWARD = 1 + BACKWARD_INPUT = 2 + BACKWARD_WEIGHT = 3 + UNSHARD = 4 + RESHARD = 5 + SEND_F = 6 + RECV_F = 7 + SEND_B = 8 + RECV_B = 9 + FULL_BACKWARD = 10 + OVERLAP_F_B = 11 + REDUCE_GRAD = 12 + + def __str__(self): + str_map = { + _ComputationType.FORWARD: "F", + _ComputationType.BACKWARD_INPUT: "I", + _ComputationType.BACKWARD_WEIGHT: "W", + _ComputationType.UNSHARD: "UNSHARD", + _ComputationType.RESHARD: "RESHARD", + _ComputationType.SEND_F: "SEND_F", + _ComputationType.RECV_F: "RECV_F", + _ComputationType.SEND_B: "SEND_B", + _ComputationType.RECV_B: "RECV_B", + _ComputationType.FULL_BACKWARD: "B", + _ComputationType.OVERLAP_F_B: "OVERLAP_F_B", + _ComputationType.REDUCE_GRAD: "REDUCE_GRAD", + } + return str_map[self] + + @staticmethod + def from_str(action): + if action == "F": + return _ComputationType.FORWARD + elif action == "I": + return _ComputationType.BACKWARD_INPUT + elif action == "W": + return _ComputationType.BACKWARD_WEIGHT + elif action == "UNSHARD": + return _ComputationType.UNSHARD + elif action == "RESHARD": + return _ComputationType.RESHARD + elif action == "SEND_F": + return _ComputationType.SEND_F + elif action == "RECV_F": + return _ComputationType.RECV_F + elif action == "SEND_B": + return _ComputationType.SEND_B + elif action == "RECV_B": + return _ComputationType.RECV_B + elif action == "B": + return _ComputationType.FULL_BACKWARD + elif action == "OVERLAP_F_B": + return _ComputationType.OVERLAP_F_B + elif action == "REDUCE_GRAD": + return _ComputationType.REDUCE_GRAD + else: + raise RuntimeError(f"Invalid computation type {action}") + + +FORWARD = _ComputationType.FORWARD +BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT +BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT +UNSHARD = _ComputationType.UNSHARD +RESHARD = _ComputationType.RESHARD +SEND_F = _ComputationType.SEND_F +RECV_F = _ComputationType.RECV_F +SEND_B = _ComputationType.SEND_B +RECV_B = _ComputationType.RECV_B +FULL_BACKWARD = _ComputationType.FULL_BACKWARD +OVERLAP_F_B = _ComputationType.OVERLAP_F_B +REDUCE_GRAD = _ComputationType.REDUCE_GRAD + +# Convenience shorthand for compute actions only since they are used in 'simple schedule format' +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) +_action_regex = re.compile( + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|REDUCE_GRAD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" +) + + +class _Action(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: int | None = None + sub_actions: tuple["_Action", ...] | None = None + + def __str__(self): + return self.__repr__() + + def __repr__(self): + if self.sub_actions is not None: + # Use recursive repr for sub_actions + sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions] + return f"({';'.join(sub_action_reprs)}){self.computation_type}" + else: + repr_str = str(self.stage_index) + repr_str += str(self.computation_type) + if self.microbatch_index is not None: + repr_str += str(self.microbatch_index) + return repr_str + + @property + def is_compute_op(self) -> bool: + return self.computation_type in ( + FORWARD, + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + OVERLAP_F_B, + ) + + @staticmethod + def from_str(action_string: str): + """ + Reverse of __repr__ + + String should be formatted as [stage][action type][(microbatch)] + e.g. `2F0`, `1UNSHARD`, `3SEND_F1` + """ + action_string = action_string.strip() + if action_string == "": + return None + + # Check for sub_actions format: [sub_action1;sub_action2;...]ComputationType + if action_string.startswith("(") and ")" in action_string: + # Find the closing bracket to separate sub_actions from computation type + bracket_end = action_string.find(")") + sub_part = action_string[ + 1:bracket_end + ] # Remove '[' and get content before ']' + computation_type_part = action_string[ + bracket_end + 1 : + ] # Get part after ']' + + # Parse sub_actions + sub_actions = [] + if sub_part.strip(): + for sub_str in sub_part.split(";"): + sub_action = _Action.from_str(sub_str.strip()) + if sub_action is not None: + sub_actions.append(sub_action) + + # For sub_actions format, we create an action with just the computation type + # The stage_index and microbatch_index are not meaningful for the container action + return _Action( + stage_index=-1, # Placeholder, not meaningful for sub_actions container + computation_type=_ComputationType.from_str(computation_type_part), + microbatch_index=None, + sub_actions=tuple(sub_actions) if sub_actions else None, + ) + + # Handle regular single action format + if match := _action_regex.match(action_string): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif action_string == "": + return None + raise RuntimeError( + f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +@lru_cache +def _get_profiler_function_name(action: _Action) -> str: + return f"PP:{str(action)}" + + +def _format_pipeline_order( + pipeline_order: dict[int, list[_Action | None]], + error_step_number: int | None = None, +) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions + and returns the formatted string. + + If `error_step_number` is passed in, an additional label will be added to signify which step + that it is erroring on. + """ + + # don't mutate the original + pipeline_order = copy.deepcopy(pipeline_order) + + # Replace None with "" + for rank in pipeline_order: + for i in range(len(pipeline_order[rank])): + if pipeline_order[rank][i] is None: + # TODO make a real 'None action' that prints as empty string and make mypy happy + pipeline_order[rank][i] = "" # type: ignore[call-overload] + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + # pyrefly: ignore [no-matching-overload] + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + + ( + " <-- ERROR HERE" + if error_step_number is not None + and int(label.split()[1]) == error_step_number + else "" + ) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" + return formatted_table + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Callable[..., torch.Tensor] | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + + # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti` + self.scale_grads = scale_grads + + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: list[torch.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._loss_fn is not None: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._loss_fn is not None and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. " + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + raise NotImplementedError + + @abstractmethod + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs=True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + raise NotImplementedError + + def eval(self, *args, target=None, losses: list | None = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches, calling forward only. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target values for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Save the original has_backward state + original_has_backward = self._has_backward + try: + self._has_backward = False + return self.step(*args, target=target, losses=losses, **kwargs) + finally: + # Restore the original state + self._has_backward = original_has_backward + + def _check_inputs( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ) -> tuple[list, list]: + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: list[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]: + """ + Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return [] + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops) + + +def _sorted_batch_p2p( + p2p_ops: list[dist.P2POp], desc: str | None = None +) -> dict[int, list[dist.Work]]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, list[dist.Work]] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +def _wait_batch_p2p(work: list[dist.Work]): + """ + Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p). + """ + for w in work: + w.wait() + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + self._stage_forward_initialized = False + self._stage_backward_initialized = False + + if n_microbatches < self._num_stages: + raise ValueError( + f"Number of microbatches ({n_microbatches}) must be greater than \ +or equal to the number of stages ({self._num_stages})." + ) + + self.pipeline_order: dict[int, list[_Action | None]] | None = ( + self._get_pipeline_order() + ) + + def _initialize_stage(self, args, kwargs): + if not self._stage_forward_initialized: + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + self._stage_forward_initialized = True + + if self._has_backward and not self._stage_backward_initialized: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_backward_initialized = True + + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs: bool = True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + if self._has_backward and not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches( + args_split, kwargs_split, targets_split, losses, return_outputs + ) + + # Return merged results per original format + if self._stage.is_last and return_outputs: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline execution order as a schedule IR. + + The returned IR is a dictionary mapping rank IDs to lists of actions. + Each action is either an _Action object representing computation to perform, + or None representing a deliberate idle step. + + The None values are used to represent pipeline bubbles where a rank + must wait for dependencies from other ranks before proceeding. However + during execution, with the _PipelineScheduleRuntime, these Nones are + skipped since the relevant communication (send/recv) will be scheduled and waited on. + + Returns: + A dictionary mapping rank -> list of actions + """ + return None + + +class _ScheduleForwardOnly(PipelineScheduleSingle): + """ + The forward-only schedule. + Will go through all the microbatches and perform only the forward pass + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule + """ + if target_mbs is not None or losses is not None: + raise RuntimeError( + "Forward-only schedule does not support loss computation" + ) + + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + output = self._stage.forward_one_chunk( + i, arg_mbs[i], kwarg_mbs[i], save_forward_output=return_outputs + ) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + # Run backward + # Delay send waits + bwd_sends_to_wait: list[list[dist.Work]] = [] + for i in range(self._n_microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk( + i, + loss=loss, + last_backward=i == self._n_microbatches - 1, + ) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + _wait_batch_p2p(work) + + # Update losses if there is a container passed in + self._update_losses(self._stage, losses) + + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline order for GPipe schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[_Action | None] = [] + + # 1. Initial delay based on rank position + warmup_delay = rank + actions.extend([None] * warmup_delay) + + # 2. Forward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx)) + + # 3. Wait period before backward passes can begin + backward_delay = 3 * (pp_group_size - 1 - rank) + actions.extend([None] * backward_delay) + + # 4. Backward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx)) + + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) + + return pipeline_order # type: ignore[return-value] + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + return_outputs: whether to return the outputs from the last stage. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work: list[dist.Work] = [] + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv")) + + # Compute + output = self._stage.forward_one_chunk( + fwd_mb_index, + arg_mbs[fwd_mb_index], + kwarg_mbs[fwd_mb_index], + save_forward_output=return_outputs, + ) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + _wait_batch_p2p(send_work) + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last forward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv")) + + # Now do the fwd + output = self._stage.forward_one_chunk( + fwd_mb_index, + arg_mbs[fwd_mb_index], + kwarg_mbs[fwd_mb_index], + save_forward_output=return_outputs, + ) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Clear previous chunk's backward sends (hopefully they have well finished) + _wait_batch_p2p(send_work) + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + _wait_batch_p2p(send_work) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + + def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None: + """ + Returns the pipeline order for 1F1B schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[_Action | None] = [] + + # 1. Warmup phase: initial delay based on rank + actions.extend([None] * rank) + + # 2. Initial forward passes before 1F1B phase + num_forward = (pp_group_size - 1) - rank + forward_mb = 0 + for i in range(num_forward): + actions.append(_Action(rank, _ComputationType.FORWARD, i)) + forward_mb = i + + # 3. Wait for backward to be ready + wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank)) + actions.extend([None] * wait_for_1f1b) + + # 4. 1F1B steady state phase + backward_mb = 0 + remaining_forward = self._n_microbatches - num_forward + + while remaining_forward > 0: + # One forward + forward_mb += 1 + actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb)) + remaining_forward -= 1 + + # One backward + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + + # 5. Cooldown phase: remaining backward passes + remaining_backward = self._n_microbatches - backward_mb + + while remaining_backward > 0: + # Add None and backward actions in alternating pattern + # based on distance from the last stage + if (pp_group_size - rank) > 0: + actions.append(None) + # Decrement the wait counter only if we still have backward passes to do + if remaining_backward > 0: + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + else: + # If we're at the last stage, just add backward actions without None + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) + return pipeline_order + + +def _requires_reduce_grad(action_type: _ComputationType) -> bool: + return action_type in (W, B) + + +def _add_reduce_grad( + actions: list[_Action | None], n_microbatches: int +) -> list[_Action | None]: + """ + REDUCE_GRAD refers to joint across minibatches grad reduction. + reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage. + """ + actions_with_reduce_grad: list[_Action | None] = [] + cnt: dict[int, int] = defaultdict(int) + + def _leaf_action(a, to_schedule): + if _requires_reduce_grad(a.computation_type): + stage_index = a.stage_index + cnt[stage_index] += 1 + if cnt[stage_index] == n_microbatches: + to_schedule.append(stage_index) + + for a in actions: + if a is None: + continue + actions_with_reduce_grad.append(a) + schedule_reduce_grad_stage_idxs: list[int] = [] + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + _leaf_action(sub_action, schedule_reduce_grad_stage_idxs) + else: + _leaf_action(a, schedule_reduce_grad_stage_idxs) + + for stage_idx in schedule_reduce_grad_stage_idxs: + actions_with_reduce_grad.append(_Action(stage_idx, REDUCE_GRAD, None)) + return actions_with_reduce_grad + + +def _add_unshard_reshard( + compute_actions: list[_Action | None], + max_active_stages: int = 3, +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP. + + UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. + RESHARD does the opposite, releasing memory (but doing no communication) + + We abandon the "timestep lock" during lowering + + max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice + 3 stages is probably the thing we want? + (to account for having one f and one b active, and something else prefetching?) + """ + + def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]: + """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" + seen: set[int] = set() + ret: list[int] = [] + + for a in next_actions: + if a is not None: + # Handle OVERLAP_F_B actions by checking their sub_actions + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + if sub_action.stage_index not in seen: + seen.add(sub_action.stage_index) + ret.append(sub_action.stage_index) + if len(ret) >= count: + break + else: + # Regular action + if a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break + return ret + + active_stages: set[int] = set() + fsdp_aware_actions: list[_Action] = [] + + def _unshard(stage_index: int): + active_stages.add(stage_index) + fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) + + def _reshard(stage_index: int): + active_stages.remove(stage_index) + fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) + + for i, action in enumerate(compute_actions): + if action is None: + continue + + # We prefetch the next N stages we'll see, dropping existing stages to make room + next_n = next_stage_indices(max_active_stages, compute_actions[i:]) + # Fetch needs to be ordered correctly, so don't use a set + fetch = list(filter(lambda s: s not in active_stages, next_n)) + # Unclear what the best policy is for eviction, but we can maintain order so we do + evict = list(filter(lambda s: s not in next_n, active_stages)) + + # logger.debug( + # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", + # i, + # active_stages, + # fetch, + # evict, + # ) + + for stage in evict: + _reshard(stage) + for stage in fetch: + _unshard(stage) + fsdp_aware_actions.append(action) + + # Reshard all remaining active stages after processing all operations + for stage in list(active_stages): + _reshard(stage) + + return fsdp_aware_actions + + +def _merge_bw( + compute_actions: list[_Action | None], +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. + (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) + + B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient + in some cases. + """ + merged_actions = [] + while compute_actions: + action = compute_actions.pop(0) + if action is None: + continue + + # Remove any None actions and find the next non-None action + while len(compute_actions) and compute_actions[0] is None: + compute_actions.pop(0) + + # Get the next action if it exists + next_action = compute_actions[0] if len(compute_actions) > 0 else None + + if ( + action.computation_type == BACKWARD_INPUT + and next_action is not None + and next_action.computation_type == BACKWARD_WEIGHT + and action.stage_index == next_action.stage_index + and action.microbatch_index == next_action.microbatch_index + ): + merged_actions.append( + _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index) + ) + compute_actions.pop(0) + else: + merged_actions.append(action) + return merged_actions + + +def _add_send_recv( + compute_actions: dict[int, list[_Action]], + stage_to_rank: Callable[[int], int], + num_stages: int, +) -> dict[int, list[_Action]]: + """ + Transforms a compute-only schedule into a complete schedule with communication actions. + + For actions with sub-actions (OVERLAP_F_B) we ensure that all the subactions have been + computed and the communication is ready + """ + comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} + prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} + + def _has_comms(action: _Action) -> bool: + if action.computation_type == F: + return action.stage_index != num_stages - 1 and stage_to_rank( + action.stage_index + 1 + ) != stage_to_rank(action.stage_index) + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + return action.stage_index != 0 and stage_to_rank( + action.stage_index - 1 + ) != stage_to_rank(action.stage_index) + return False + + def _get_comms(action: _Action) -> tuple[_Action, _Action]: + assert _has_comms(action), f"{action} is not a valid comm action" + stage_idx = action.stage_index + ctype = action.computation_type + mb_idx = action.microbatch_index + send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) + recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 + recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) + return send, recv + + def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool: + """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. + This helps ensure a sane (non-hanging) ordering of sends and recvs. + But it also means we might not be able to schedule our next compute action yet. + """ + if action is None: + return True + elif action.computation_type == F and action.stage_index != 0: + if ( + _Action(action.stage_index, RECV_F, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) + in prev_actions + ): + return True + return False + elif ( + action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD) + and action.stage_index != num_stages - 1 + ): + if ( + _Action(action.stage_index, RECV_B, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_actions + ): + return True + return False + else: + return True + + while compute_actions: + progress = False + # go in order of ranks even if dict keys aren't ordered + for rank in sorted(compute_actions): + assert len(compute_actions[rank]) > 0, ( + f"{rank=}, {len(compute_actions[rank])=}" + ) + action = compute_actions[rank][0] + # handle case where parent action (e.g. OVERLAP_F_B) can be comprised of subactions + if action is not None and action.sub_actions is not None: + all_actions = action.sub_actions + else: + all_actions = (action,) + + if not all(_ready_to_schedule(a, prev_actions[rank]) for a in all_actions): + continue + + # The action's dependencies are satisfied, so add to schedule + if action is not None: + comm_actions[rank].append(action) + for a in all_actions: + prev_actions[rank].add(a) + if _has_comms(a): + send, recv = _get_comms(a) + # TODO we can avoid send/recv if the 2 stages are on the same rank. + # should we avoid that in the runtime or here? + comm_actions[rank].append(send) + prev_actions[rank].add(send) + comm_actions[stage_to_rank(recv.stage_index)].append(recv) + prev_actions[stage_to_rank(recv.stage_index)].add(recv) + + compute_actions[rank].pop(0) + if len(compute_actions[rank]) == 0: + del compute_actions[rank] + progress = True + assert progress, "Malformed compute schedule, can't schedule sends/recvs" + return comm_actions + + +def _validate_schedule( + actions: dict[int, list[_Action | None]], + pp_group_size: int, + num_stages: int, + num_microbatches: int, +) -> dict[int, int]: + assert len(actions) == pp_group_size, ( + f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" + ) + for rank in range(pp_group_size): + assert rank in actions, f"Schedule is missing actions for rank {rank}" + + # We will count all the actions per stage and ensure they happen in a valid order + # (e.g. F before (B, I) before W for a given microbatch) + stage_actions: dict[int, dict[_ComputationType, set]] = { + stage_id: { + F: set(), + B: set(), + I: set(), + W: set(), + } + for stage_id in range(num_stages) + } + stage_index_to_rank_mapping = {} + + def _process_action(action: _Action, rank: int, step: int): + """Process a single action and update stage_actions and stage_index_to_rank_mapping""" + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Full Backward for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + if mb_id not in stage_actions[s_id][F]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Input for stage {s_id}, " + f"microbatch {mb_id} without first running Forward" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][I].add(mb_id) + elif ctype == W: + if mb_id not in stage_actions[s_id][I]: + error_msg = ( + f"Rank {rank}, step {step}: Running Backward Weight for stage {s_id}, " + f"microbatch {mb_id} without first running Backward Input" + ) + formatted_schedule = _format_pipeline_order( + actions, error_step_number=step + ) + full_error_msg = ( + f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}" + ) + raise AssertionError(full_error_msg) + stage_actions[s_id][W].add(mb_id) + + if s_id not in stage_index_to_rank_mapping: + stage_index_to_rank_mapping[s_id] = rank + else: + existing_rank = stage_index_to_rank_mapping[s_id] + assert rank == existing_rank, ( + f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + ) + + for rank in actions: + for step, action in enumerate(actions[rank]): + if action is None: + continue + assert isinstance(action, _Action), ( + f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action" + ) + + # Check if action has sub_actions + if action.sub_actions is not None: + # Process each sub_action instead of the main action + for sub_action in action.sub_actions: + _process_action(sub_action, rank, step) + else: + # Process the main action normally + _process_action(action, rank, step) + + for s_id in stage_actions: + f_mb = len(stage_actions[s_id][F]) + b_mb = len(stage_actions[s_id][B]) + i_mb = len(stage_actions[s_id][I]) + w_mb = len(stage_actions[s_id][W]) + + assert f_mb == num_microbatches, ( + f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" + ) + + assert i_mb == w_mb, ( + f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \ + but got I={i_mb}, W={w_mb}" + ) + + assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( + f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ + but got B={b_mb}, I={i_mb}, W={w_mb}" + ) + return stage_index_to_rank_mapping + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + use_full_backward: bool | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self._stages_forward_initialized = False + self._stages_backward_initialized = False + + # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle + has_loss: bool = self._loss_fn is not None + self._should_compute_loss = lambda stage: stage.is_last and has_loss + + # This will be set during init of derived schedules + self.pipeline_order: dict[int, list[_Action | None]] = {} + + # When using a custom backward function, we may or may not need autograd to be used + # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled() + # check should be performed before the step function. + self._backward_requires_autograd = backward_requires_autograd + + if use_full_backward is not None: + logger.warning( + "Deprecation warning: 'use_full_backward' is no longer supported. " + "Simply stop passing it, and everything should still work fine." + ) + + def _initialize_stages(self, args: tuple[Any, ...], kwargs): + if not self._stages_forward_initialized: + # Prepare the communication needed for the pipeline schedule execution + # This is needed because during execution we always perform a series of batch P2P ops + # The first call of the batched P2P needs to involve the global group + all_ops: list[dist.P2POp] = [] + for stage in self._stages: + all_ops.extend(stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: tuple[Any, ...] = tuple() + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + self._stages_forward_initialized = True + + if self._has_backward and not self._stages_backward_initialized: + for stage in self._stages: + stage._prepare_backward_infra(self._n_microbatches) + self._stages_backward_initialized = True + + def _validate_and_set_stage_mapping( + self, actions: dict[int, list[_Action | None]] + ) -> None: + """ + Allocates the stage index to rank mapping which is needed for communication + """ + self.stage_index_to_group_rank = _validate_schedule( + actions, + self.pp_group_size, + self._num_stages, + self._n_microbatches, + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + def _dump_csv(self, filename): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + + def _load_csv(self, filename, format="compute_only"): + """Load a CSV representation of the schedule from a file with the provided filename. + This API will most likely get renamed/refactored so is marked as internal for now. + + format must be "compute_only" for PipelineScheduleMulti. + """ + assert format == "compute_only" + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + self.pipeline_order[rank] = [_Action.from_str(s) for s in row] + + # Validates the order of the pipeline actions and infers the stage_to_rank_mapping. + # This will overwrite the default stage_to_rank_mapping created in the constructor + self._validate_and_set_stage_mapping(self.pipeline_order) + + def step( + self, + *args, + target=None, + losses: list | None = None, + return_outputs: bool = True, + **kwargs, + ): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + return_outputs: whether to return the outputs from the last stage. + """ + if ( + self._has_backward + and self._backward_requires_autograd + and not torch.is_grad_enabled() + ): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches( + args_split, kwargs_split, targets_split, losses, return_outputs + ) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last and return_outputs: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage or we do not return output chunks + return None + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: set[int] = set() + all_next_ranks: set[int] = set() + for stage_index in stage_index_to_stage: + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) + if stage_index < self._num_stages - 1: + all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: list[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, + arg_mbs[mb_index], + kwarg_mbs[mb_index], + save_forward_output=return_outputs, + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.FULL_BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_INPUT: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_WEIGHT: + # perform weight update + stage = stage_index_to_stage[stage_index] + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_weight_one_chunk( + mb_index, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type in ( + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + ): + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle receives for the backwards from a next rank + if computation_type in (FORWARD, BACKWARD_WEIGHT): + # Next rank doing forward or weight update has no influence for the current rank backward recv + pass + elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + _wait_batch_p2p(_batch_p2p(ops)) + except Exception as e: + logger.error( # noqa: G200 + "[Rank %s] pipeline schedule %s caught the following exception '%s' \ +at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + str(e), + time_step, + action, + ) + logger.error( + "%s", + _format_pipeline_order( + self.pipeline_order, error_step_number=time_step + ), + ) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class _PipelineContext: + def __init__( + self, + schedule_ref: _PipelineSchedule, + arg_mbs: list[tuple] | None = None, + kwarg_mbs: list[dict] | None = None, + target_mbs: list | None = None, + losses: list | None = None, + ): + self.schedule_ref = schedule_ref + self.arg_mbs = arg_mbs + self.kwarg_mbs = kwarg_mbs + self.target_mbs = target_mbs + self.losses = losses + + +class _CustomFunctionProtocol(Protocol): + def __call__(self, action: _Action, ctx: _PipelineContext) -> None: ... + + +class _PipelineScheduleRuntime(PipelineScheduleMulti): + """ + Provides a simple runtime that requires a 'schedule IR' including specified communication operations. + + Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be + subclassed and the subclass can be responsible for creating a schedule IR. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Action to custom function mapping + self._comp_type_to_function_map: dict[_ComputationType, Callable] = {} + # count either full_backward or backward_weight together, to determine when to sync DP grads + self.backward_counter: Counter[int] = Counter() + + # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use + self.bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + self.fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + + # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages + self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list) + self.unsharded_stages = set() + + def register_custom_function( + self, + computation_type: _ComputationType, + custom_function: _CustomFunctionProtocol, + ) -> None: + """ + Register a custom function to be executed for a specific computation type. + + Args: + computation_type: The computation type for which to register the custom function + custom_function: The function to execute when this computation type is encountered. + Must have signature: (action: _Action, ctx: _PipelineContext) -> None + """ + # Ensure that the computation type is valid + if computation_type not in ( + FORWARD, + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + OVERLAP_F_B, + UNSHARD, + RESHARD, + REDUCE_GRAD, + ): + raise ValueError( + f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \ + BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported." + ) + + # Check if computation_type is already registered + if computation_type in self._comp_type_to_function_map: + logger.warning( + "Computation type %s is already registered. " + "Overwriting the existing custom function.", + computation_type, + ) + + self._comp_type_to_function_map[computation_type] = custom_function + + def _prepare_schedule_with_comms( + self, + actions: dict[int, list[_Action | None]], + format: str = "compute_only", + ): + """ + Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including + communication actions. Stores the schedule in self, and must be called before running step_mo() + """ + # validate the provided actions are valid and overrides the default stage_index_to_group_rank + super()._validate_and_set_stage_mapping(actions) + + self.pipeline_order_with_comms: dict[int, list[_Action]] = {} + if format == "compute_comms": + for rank in actions: + self.pipeline_order_with_comms[rank] = [] + for action in actions[rank]: + assert action is not None + self.pipeline_order_with_comms[rank].append(action) + # TODO what level of validation should we offer for compute+comms schedule? + elif format == "compute_only": + # Validate that the schedule does not have comms already added to it + for rank, action_list in actions.items(): + for i, action in enumerate(action_list): + if action is not None and not action.is_compute_op: + raise ValueError( + f"Expected compute-only schedule but found communication action " + f"'{action}' at rank {rank}, position {i}. " + f"Communication actions (e.g. SEND_F, RECV_F, etc.) " + f"should not be present when format='compute_only'." + ) + + # Perform schedule lowering + for rank in actions: + self.pipeline_order_with_comms[rank] = _add_unshard_reshard( + actions[rank] + ) + self.pipeline_order_with_comms[rank] = _add_reduce_grad( # type: ignore[assignment] + self.pipeline_order_with_comms[rank], # type: ignore[arg-type] + self._n_microbatches, + ) + + self.pipeline_order_with_comms = _add_send_recv( + self.pipeline_order_with_comms, + stage_to_rank=lambda s: self.stage_index_to_group_rank[s], + num_stages=self._num_stages, + ) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _load_csv(self, filename: str, format: str = "compute_only"): + """Loads a csv in simple format and then lowers it to include communication actions + + format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes + will automatically be run to generate a compute_comms schedule. + """ + if format == "compute_only": + # this will populate self.pipeline_order + super()._load_csv(filename) + # this will populate self.pipeline_order_with_comms + self._prepare_schedule_with_comms(self.pipeline_order) + elif format == "compute_comms": + actions = {} + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + actions[rank] = [_Action.from_str(s) for s in row] + self._prepare_schedule_with_comms(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str, format: str = "compute_comms"): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + if format == "compute_only": + assert self.pipeline_order is not None, ( + "Compute only schedule must be available" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + elif format == "compute_comms": + assert self.pipeline_order_with_comms is not None, ( + "Must initialize compute_comms schedule before dump_csv" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) + + def _simulate(self): + return _simulate_comms_compute( + self.pipeline_order_with_comms, + lambda s: self.stage_index_to_group_rank[s], + self._num_stages, + ) + + def _assert_unsharded(self, stage: _PipelineStageBase): + """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared.""" + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + if stage_uses_fsdp: + stage_idx = stage.stage_index + if stage_idx in self.unshard_ops: + for op in self.unshard_ops[stage_idx]: + op.wait() + del self.unshard_ops[stage_idx] + self.unsharded_stages.add(stage_idx) + assert stage_idx in self.unsharded_stages, ( + f"Attempted to compute on sharded {stage_idx=}" + ) + + def _step_microbatches( + self, + arg_mbs: list | None = None, + kwarg_mbs: list | None = None, + target_mbs: list | None = None, + losses: list | None = None, + return_outputs: bool = True, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + assert self.pipeline_order_with_comms is not None, ( + "Must call _prepare_schedule_with_comms() before calling _step_microbatches()" + ) + + # send ops should be waited on before step() exists, mainly for hygiene + send_ops: list[list[dist.Work]] = [] + + def _perform_action(action: _Action) -> None: + comp_type = action.computation_type + mb_index: int = ( + action.microbatch_index if action.microbatch_index is not None else -1 + ) + assert mb_index >= 0 or comp_type in ( + UNSHARD, + RESHARD, + REDUCE_GRAD, + ), f"{action=} missing mb_index" + stage_idx = action.stage_index + stage = stage_index_to_stage[stage_idx] + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + # see [Note: V-schedule special case] + is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage + + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them. + if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in self.fwd_recv_ops, ( + f"Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) + self.fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in self.bwd_recv_ops, ( + f"Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) + self.bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in self.unsharded_stages + and stage_idx not in self.unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + for submodule in stage.submod.modules(): + if not isinstance(submodule, FSDPModule): + continue + handle = cast(UnshardHandle, submodule.unshard(async_op=True)) + self.unshard_ops[stage_idx].append(handle) + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert stage_idx in self.unsharded_stages, ( + f"Resharding {stage_idx=} without unsharding" + ) + assert stage_idx not in self.unshard_ops, ( + f"Resharding {stage_idx=} before finishing unshard" + ) + for submodule in stage.submod.modules(): + if not isinstance(submodule, FSDPModule): + continue + submodule.reshard() + self.unsharded_stages.remove(stage_idx) + elif comp_type == FORWARD: + self._assert_unsharded(stage) + + if ( + not stage.is_first + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_prev_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in self.fwd_recv_ops, ( + f"Computing {action=} before receiving input" + ) + _wait_batch_p2p(self.fwd_recv_ops.pop((stage_idx, mb_index))) + + output = stage.forward_one_chunk( + mb_index, + arg_mbs[mb_index], # type: ignore[index] + kwarg_mbs[mb_index], # type: ignore[index] + save_forward_output=return_outputs, + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_idx + 1].set_local_fwd_input( + output, mb_index + ) + + elif comp_type == FULL_BACKWARD: + self._assert_unsharded(stage) + + if ( + not stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in self.bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + self.backward_counter[stage_idx] += 1 + last_backward = self.backward_counter[stage_idx] == self._n_microbatches + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_INPUT: + self._assert_unsharded(stage) + + if not stage.is_last and not is_next_stage_on_this_rank: + assert ( + stage_idx, + mb_index, + ) in self.bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_WEIGHT: + self._assert_unsharded(stage) + self.backward_counter[stage_idx] += 1 + last_backward = self.backward_counter[stage_idx] == self._n_microbatches + stage.backward_weight_one_chunk( + mb_index, + last_backward=last_backward, + ) + elif comp_type == REDUCE_GRAD: + grad_scale_factor = self._n_microbatches if self.scale_grads else 1 + stage.perform_reduce_grad(grad_scale_factor) + else: + raise ValueError(f"{action=} is unknown or unsupported") + + # count either full_backward or backward_weight together, to determine when to sync DP grads + self.backward_counter.clear() + for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) + try: + with record_function(_get_profiler_function_name(action)): + if action.computation_type in self._comp_type_to_function_map: + ctx = _PipelineContext( + self, + arg_mbs, + kwarg_mbs, + target_mbs, + losses, + ) + self._comp_type_to_function_map[action.computation_type]( + action, ctx + ) + elif action.computation_type == OVERLAP_F_B: + assert action.sub_actions is not None, "sub_actions must be set" + for sub_a in action.sub_actions: + _perform_action(sub_a) + else: + _perform_action(action) + except Exception as e: + logger.error( + "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", + time_step, + action, + ) + logger.error( + _format_pipeline_order( + self.pipeline_order_with_comms, # type: ignore[arg-type] + error_step_number=time_step, + ) + ) + raise e + + # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them + while send_ops: + _wait_batch_p2p(send_ops.pop()) + + assert len(self.unshard_ops) == 0, "Unused unshard operations" + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(_PipelineScheduleRuntime): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + What is different is that when microbatches are ready for multiple local + stages, Loops BFS will prioritizes the earlier stage, running all available + microbatches at once. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | _Loss | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[_Action | None] = [None for _ in range(rank)] + + for stage_index in stage_indices: + rank_ops.extend( + _Action(stage_index, _ComputationType.FORWARD, mb_index) + for mb_index in range(self._n_microbatches) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + rank_ops.extend( + _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index) + for mb_index in reversed(range(self._n_microbatches)) + ) + return rank_ops + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: dict[int, int] = defaultdict(int) + bwd_stage_mb_index: dict[int, int] = defaultdict(int) + weight_stage_mb_index: dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[_Action | None] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + return rank_ops + + +class ScheduleInterleaved1F1B(_PipelineScheduleRuntime): + """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). + + This schedule is mostly similar to the original paper. + It differs by being relaxing the requirement of num_microbatch % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. As a few examples, support + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) + + +class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime): + """ + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. + + In particular this is implementing the ZB1P schedule in the paper. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Zero bubble requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # This function add bubbles to the generated schedule based on dependencies of actions + # Note that the ZB1P schedule will not require bubbles to be manually added and it is + # only useful when n_microbatches <= microbatches_per_round + self.pipeline_order = self._add_bubbles_to_actions( + self.n_local_stages * self.pp_group_size, + ) + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 1 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + num_1f1b_microbatches = rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, + ) + + def _add_bubbles_to_actions(self, num_stages_global): + actions = self.pipeline_order + + def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): + if op == _ComputationType.FORWARD: + if stage != 0 and (stage - 1, op, microbatch) not in seen_ops: + return True + elif op == _ComputationType.FULL_BACKWARD: + if stage == num_stages_global - 1: + return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops + return (stage + 1, op, microbatch) not in seen_ops + return False + + seen_ops: set[tuple[int, _ComputationType, int]] = set() + result: dict[int, list[_Action | None]] = {} + next_pointer: dict[int, int] = {} + bubbles_added: dict[int, int] = {} + total_bubbles_added = 0 + + for rank in range(self.pp_group_size): + result[rank] = [] + next_pointer[rank] = 0 + bubbles_added[rank] = 0 + + while True: + should_stop = True + + temp_seen_ops: set[tuple[int, _ComputationType, int]] = set() + + for rank in range(self.pp_group_size): + timestamp = next_pointer[rank] + if timestamp >= len(actions[rank]): + continue + + should_stop = False + + if actions[rank][timestamp] is not None: + temp_action = actions[rank][timestamp] + assert temp_action is not None + stage_index, op, microbatch, _ = temp_action + if not need_bubble( + stage_index, op, microbatch, num_stages_global, seen_ops + ): + result[rank].append(actions[rank][timestamp]) + if microbatch is not None: + temp_seen_ops.add((stage_index, op, microbatch)) + next_pointer[rank] += 1 + else: + result[rank].append(None) + bubbles_added[rank] += 1 + else: + next_pointer[rank] += 1 + result[rank].append(None) + + seen_ops.update(temp_seen_ops) + if should_stop: + break + + if total_bubbles_added > 0: + logger.warning( + "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s", + total_bubbles_added, + bubbles_added, + ) + return result + + +class ScheduleZBVZeroBubble(_PipelineScheduleRuntime): + """ + The Zero Bubble schedule (ZBV variant). + See https://arxiv.org/pdf/2401.10241 Section 6 for details. + + This schedules requires exactly two stages per rank. + + This schedule will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses backward with respect to weights to fill in + the pipeline bubble. + + This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights. + In practice, this is not likely true for real models so alternatively + a greedy scheduler could be implemented for unequal/unbalanced time. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages, style="v" + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self.n_local_stages = len(stages) + if self.n_local_stages != 2: + raise ValueError( + "ZBV requires exactly 2 stages per rank, but got " + f"{self.n_local_stages}." + ) + + self.rank = stages[0].group_rank + self.num_stages = stages[0].num_stages + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least + # as large of the number of microbatches needed to fully utilize the pipeline + n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches) + rank_ops: list[_Action | None] = [None for _ in range(rank)] + + # Forward and backward action counts for stage chunk 0 and chunk 1 + f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0 + # warm-up phase + warmup_n1 = 2 * (self.pp_group_size - rank) - 1 + stage_id_chunk0 = rank + stage_id_chunk1 = self.num_stages - 1 - rank + + for _ in range(warmup_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n2 = rank + for _ in range(warmup_n2): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n3 = self.pp_group_size - rank + for _ in range(warmup_n3): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # stable phase + while f1_cnt < f0_cnt or f0_cnt < n_micro: + if f0_cnt < n_micro: + rank_ops.append( + _Action( + stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt + ) + ) + f0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # cool-down phase + w0_cnt, w1_cnt = b0_cnt, b1_cnt + cooldown_n1 = rank + for _ in range(cooldown_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + cooldown_n2 = self.pp_group_size - rank + for _ in range(cooldown_n2): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + while w1_cnt < b1_cnt: + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt) + ) + w1_cnt += 1 + while w0_cnt < b0_cnt: + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + + assert w0_cnt == b0_cnt and b0_cnt == f0_cnt + assert w1_cnt == b1_cnt and b1_cnt == f1_cnt + # We use max() in the n_micro computation above, so we may need to + # remove redundant microbatches + rank_ops = [ + ( + action + if action is not None + and action.microbatch_index is not None + and action.microbatch_index < self._n_microbatches + else None + ) + for action in rank_ops + ] + return rank_ops + + +class ScheduleDualPipeV(_PipelineScheduleRuntime): + """ + The DualPipeV schedule. A more efficient schedule variant based on the + DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437 + + Based on the open sourced code from https://github.com/deepseek-ai/DualPipe + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Callable | None = None, + args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None, + kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None, + output_merge_spec: dict[str, Any] | tuple[Any] | None = None, + scale_grads: bool = True, + backward_requires_autograd: bool = True, + ): + # TODO: we dont support input/weight backward split with torch.compile + _check_torch_compile_compatibility(stages, self.__class__.__name__) + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, + ) + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages, style="v" + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self.n_local_stages = len(stages) + if self.n_local_stages != 2: + raise ValueError( + "ZBV requires exactly 2 stages per rank, but got " + f"{self.n_local_stages}." + ) + if n_microbatches < self._num_stages: + raise ValueError( + "DualPipeV requires at least as many microbatches as stages, but got " + f"{n_microbatches} microbatches and {self._num_stages} stages." + ) + + self.rank = stages[0].group_rank + self.num_stages = stages[0].num_stages + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[_Action | None]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime + self._prepare_schedule_with_comms(self.pipeline_order) + + def _calculate_single_rank_operations(self, rank) -> list[_Action | None]: + actions: list[_Action | None] = [] + counters: dict[ + tuple[int, _ComputationType], int + ] = {} # (stage_index, computation_type) -> mb_index + weight_queue = [] # Queue of (stage_index, mb_index) for pending weight actions + + num_ranks = self.pp_group_size + num_chunks = self._n_microbatches + + rank_to_stages = generate_rank_to_stage_mapping( + num_ranks, num_ranks * 2, style="v" + ) + stage0_index, stage1_index = rank_to_stages[rank] + + def increment_backward_counts(stage_index: int): + """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used.""" + input_key = (stage_index, BACKWARD_INPUT) + weight_key = (stage_index, BACKWARD_WEIGHT) + counters[input_key] = counters.get(input_key, 0) + 1 + counters[weight_key] = counters.get(weight_key, 0) + 1 + + def add_overlap_f_b( + actions: list, + forward_stage: int, + backward_stage: int, + ): + """Helper method to add an overlapped forward+backward action which tracks microbatch index.""" + # Create new overlapped forward+backward action with sub_actions + forward_key = (forward_stage, FORWARD) + backward_key = (backward_stage, BACKWARD_INPUT) + + forward_mb = counters.get(forward_key, 0) + backward_mb = counters.get(backward_key, 0) + + sub_actions = ( + _Action(forward_stage, FORWARD, forward_mb), + _Action(backward_stage, FULL_BACKWARD, backward_mb), + ) + actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions)) + + # Update counters for sub_actions + counters[forward_key] = forward_mb + 1 + increment_backward_counts(backward_stage) + + def add_action( + actions: list, + stage_index: int, + computation_type: _ComputationType, + ): + # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter + key = ( + (stage_index, computation_type) + if computation_type != FULL_BACKWARD + else (stage_index, BACKWARD_INPUT) + ) + mb_index = counters.get(key, 0) + actions.append(_Action(stage_index, computation_type, mb_index)) + + # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters + if computation_type == FULL_BACKWARD: + increment_backward_counts(stage_index) + else: + # If BACKWARD_INPUT is updated, add corresponding weight action to queue + if computation_type == BACKWARD_INPUT: + # Add weight action to queue for later processing + weight_queue.append((stage_index, mb_index)) + counters[key] = mb_index + 1 + + def add_weight_action_if_pending(actions: list): + """Helper method to add a weight action from the queue.""" + if not weight_queue: + return # No pending weight actions, skip + # Pop the oldest weight action from the queue + actual_stage_index, weight_mb_index = weight_queue.pop(0) + actions.append( + _Action( + actual_stage_index, + BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + # Update the counter for the actual stage that was processed + weight_key = (actual_stage_index, BACKWARD_WEIGHT) + counters[weight_key] = counters.get(weight_key, 0) + 1 + + # Step 1: F0 + step_1 = (num_ranks - rank - 1) * 2 + for _ in range(step_1): + add_action(actions, stage0_index, FORWARD) + + # Step 2: F0F1 + step_2 = rank + 1 + for _ in range(step_2): + add_action(actions, stage0_index, FORWARD) + add_action(actions, stage1_index, FORWARD) + + # Step 3: I1W1F1 (Use zero bubble) + step_3 = num_ranks - rank - 1 + for _ in range(step_3): + add_action(actions, stage1_index, BACKWARD_INPUT) + add_weight_action_if_pending(actions) + add_action(actions, stage1_index, FORWARD) + + # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward) + step_4 = num_chunks - num_ranks * 2 + rank + 1 + for i in range(step_4): + if i == 0 and rank == num_ranks - 1: + # NOTE: We don't overlap these two chunks to further reduce bubble size. + add_action(actions, stage0_index, FORWARD) + add_action(actions, stage1_index, FULL_BACKWARD) + else: + add_overlap_f_b( + actions, + forward_stage=stage0_index, + backward_stage=stage1_index, + ) + add_overlap_f_b( + actions, + forward_stage=stage1_index, + backward_stage=stage0_index, + ) + + # Step 5: B1-F1B0 + step_5 = num_ranks - rank - 1 + for _ in range(step_5): + add_action(actions, stage1_index, FULL_BACKWARD) + add_overlap_f_b( + actions, + forward_stage=stage1_index, + backward_stage=stage0_index, + ) + + # Step 6: B1B0 (The second half of the chunks use zero bubble) + step_6 = rank + 1 + enable_zb = False + for i in range(step_6): + if i == step_6 // 2 and rank % 2 == 1: + enable_zb = True + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage1_index, comp_type) + if i == step_6 // 2 and rank % 2 == 0: + enable_zb = True + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage0_index, comp_type) + + # Step 7: W0B0 + step_7 = num_ranks - rank - 1 + for _ in range(step_7): + add_weight_action_if_pending(actions) + comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD + add_action(actions, stage0_index, comp_type) + + # Step 8: W0 + step_8 = rank + 1 + for _ in range(step_8): + add_weight_action_if_pending(actions) + + return actions + + +def get_schedule_class(schedule_name: str): + """ + Maps a schedule name (case insensitive) to its corresponding class object. + + Args: + schedule_name (str): The name of the schedule. + """ + schedule_map = { + "1F1B": Schedule1F1B, + "Interleaved1F1B": ScheduleInterleaved1F1B, + "GPipe": ScheduleGPipe, + "LoopedBFS": ScheduleLoopedBFS, + "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, + "PipelineScheduleSingle": PipelineScheduleSingle, + "PipelineScheduleMulti": PipelineScheduleMulti, + "ZBVZeroBubble": ScheduleZBVZeroBubble, + "DualPipeV": ScheduleDualPipeV, + } + lowercase_keys = {k.lower(): k for k in schedule_map} + lowercase_schedule_name = schedule_name.lower() + if lowercase_schedule_name not in lowercase_keys: + raise ValueError( + f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}" + ) + return schedule_map[lowercase_keys[lowercase_schedule_name]] + + +def _simulate_comms_compute( + pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int +): + """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags + any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank + can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used + as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number + of simulated steps. + + The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams. + Future work may be to enhance this and model the compute time, comms overlap, and even memory. + """ + pipeline_order = { + rank: [a for a in pipeline_order[rank] if a is not None] + for rank in sorted(pipeline_order) + } + _schedule: dict[int, list[_Action | None]] = { + rank: [] for rank in sorted(pipeline_order) + } + + _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule} + + def add_to_schedule(rank: int, action: _Action | None): + _schedule[rank].append(action) + if action is not None: + _prev_ops_rank[rank].add(action) + + def _ready_to_schedule(action: _Action | None) -> bool: + if action is None: + return True + + stage_idx = action.stage_index + prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)] + if action.computation_type == F: + if action.stage_index == 0: + return True + elif ( + _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops + ): + return True + return False + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + if action.stage_index == num_stages - 1: + return True + if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops: + return True + if ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_ops + ): + return True + if ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_ops + ): + return True + return False + elif action.computation_type == BACKWARD_WEIGHT: + return True + elif action.computation_type == SEND_F: + expected_f = _Action(action.stage_index, F, action.microbatch_index) + return expected_f in prev_ops + elif action.computation_type == RECV_F: + peer_stage_idx = stage_idx - 1 + expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + elif action.computation_type == SEND_B: + expected_b = _Action( + action.stage_index, BACKWARD_INPUT, action.microbatch_index + ) + expected_bw = _Action( + action.stage_index, FULL_BACKWARD, action.microbatch_index + ) + return expected_b in prev_ops or expected_bw in prev_ops + elif action.computation_type == RECV_B: + peer_stage_idx = stage_idx + 1 + expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + else: + raise ValueError(f"Unsupported action type {action}") + + while pipeline_order: + progress = False + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + add_to_schedule(rank, action) + pipeline_order[rank].pop(0) + progress = True + else: + add_to_schedule(rank, None) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked + # by one of the later ranks + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + if _schedule[rank][-1] is not None: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + _schedule[rank][-1] = action + _prev_ops_rank[rank].add(action) + pipeline_order[rank].pop(0) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + if not progress: + print("WIP comms schedule:\n", _format_pipeline_order(_schedule)) + for rank in pipeline_order: + print(f"{rank=} next action= {pipeline_order[rank][0]}") + raise ValueError("Schedule is not progressing") + + return _schedule + + +def _dump_chrometrace(schedule, filename): + """ + This function dumps a schedule IR into a chrometrace format so it can be visualized. + + It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text. + + As future work we may extend this to include more accurate heuristics for durations, or let users input durations, + add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute + as separate streams on the chrometrace view. + """ + events = [] + for rank in sorted(schedule): + for timestep, action in enumerate(schedule[rank]): + if action is None: + continue + events.append( + { + "name": str(action), + "cat": ( + "computation" + if action.computation_type in (F, B, W) + else "communication" + ), + "ph": "X", + "pid": rank, + "tid": rank, + "ts": timestep, + "dur": 1, + } + ) + import json + + with open(filename, "w") as f: + json.dump({"traceEvents": events}, f) + + +def _check_torch_compile_compatibility( + stages: list[_PipelineStageBase], schedule_name: str +): + """ + Check if the schedule is compatible with torch.compile. + + Args: + stages: List of pipeline stages to check + schedule_name: Name of the schedule for error message + + Raises: + RuntimeError: If any stage uses torch.compile + """ + for stage in stages: + if not isinstance(stage.submod, torch.nn.Module): + continue + + for module in stage.submod.modules(): + if isinstance(module, OptimizedModule): + raise RuntimeError( + f"The {schedule_name} schedule is not supported with " + "stage modules that have used torch.compile. " + f"Found OptimizedModule in {type(module).__name__}" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0d51020458bcfd45cdc34c45868dc374bc2564 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py @@ -0,0 +1,1588 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, cast, Union + +import torch +import torch.distributed as dist +import torch.fx as fx +import torch.nn as nn +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed._composable.replicate_with_fsdp import replicate, ReplicateModule +from torch.distributed.fsdp import FSDPModule, fully_shard +from torch.fx.node import Argument, map_aggregate +from torch.nn.parallel import DistributedDataParallel +from torch.utils._pytree import tree_map_only + +from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._debug import map_debug_info +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata + + +__all__ = [ + "PipelineStage", + "build_stage", +] + +logger = logging.getLogger(__name__) + + +def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: + """[Note: pipeline model output type] + + The output of the model passed to pipelining can be any type, controlled by the user. + + However, there are 2 API surfaces that complicate this. + (1) the outputs of intermediate stages are passed via Send/Recv ops to subsequent stages. The implicit assumption + is that each element of the outputs is a tensor. Otherwise, Send/Recv would not be supported. The exception + is the last layer of the model, which can output anything any which won't be communicated via Send/Recv. + (2) the outputs of the last layer of the model are returned to the user, or, passed to the loss function. + The loss function can be written in any way, such that its inputs match the outputs of the model. + + It would be convenient if we could strictly type the output signature of the pipeline stage wrapping the model, + but we do not want to impose an unnecessary constraint on user provided models. + + Currently, we let user provided models return either a Tensor or a tuple of Tensors from each stage. Due to + torch.export tracing, compiled models may also return a list instead of a Tuple, which we will normalize back to a + tuple for consistency. + + TODO: should we be stricter about asserting that stage modules (intermediate and output) all return only Tensor + values? + """ + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + return output_tuple + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensor): + self.meta = tensor.to("meta") + + +class _RecvInfo: + """ + Represents a stage input. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: torch.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: torch.Tensor | FakeTensor, + device: torch.device, +) -> torch.Tensor: + """ + Create a real tensor from a tensor. + """ + return torch.empty( + example.size(), + dtype=example.dtype, + layout=example.layout, + device=device, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. + """ + + def __init__( + self, + submodule: torch.nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, + ): + """ + Args: + submodule (torch.nn.Module): The module to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + device (torch.device): The device to run this stage on. + group (Optional[dist.ProcessGroup]): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder is a builder function + that will build a new dw_runner function that will run parts of module backward that were intentionally + skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs + model backwards, and stage should save the latest dw_runner to run during weight pas (W). + If not provided, a dw_runner will be generated automatically by traversing the autograd graph. + When used with schedules that only have F and B steps, the fresh dw_runner function will be called as + part of I (input backwards). When used with F,I,W schedules, the dw_runner function implements 'W'. + """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.submod = submodule + self.stage_index = stage_index + self.num_stages = num_stages + # pyrefly: ignore [read-only] + self.device = device + self.group = group + + self.dw_builder = dw_builder + + # backward state + self.backward_state: dict[int, tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: tuple[torch.Tensor, ...] | None = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} + # map microbatch ID to list of backward grad tensor args + self.bwd_cache: dict[int, tuple[torch.Tensor | None, ...]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: list[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.act_send_info: dict[int, list] = {} + + # Backward infra will created lazily + self.grad_recv_info: dict = {} + self.grad_send_info: list | None = None + + # To be populated later by the Schedule + self.chunks: int | None = None + self.stage_index_to_group_rank: dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: tuple[torch.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. + """ + assert self._outputs_meta is None, ( + "Attempting to reconfigure output_meta, which is not supported" + ) + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: + """Get the output metadata (meta tensors) representing the outputs of this stage""" + assert self._outputs_meta is not None, ( + "Attempted to get_outputs_meta() without configuring output meta" + ) + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: tuple, + ) -> list[int | None]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: list[int | None] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to discard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_aggregate(args_recv_info, map_recv_to_send) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: tuple[InputInfo, ...], + ) -> list[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: list[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append( + dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) + ) + + return ops + + """[Note: V-schedule special case] + + V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + + ex: 2 ranks, 4 stages forms a simple V: + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 + + stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to + use communication ops. Instead, they should pass tensor data directly via function call. + + set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and + should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + """ + + def set_local_fwd_input(self, prev_stage_outputs: Any, mb_index: int) -> None: + """ + Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids + copying tensor data or using send/recv op. Detaches original tensor and sets requires_grad so the + tensor can serve as a leaf for autograd and gradients can be collected from it during backward. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[mb_index] + + # See [Note: pipeline model output type] + prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) + + for info, tensor in zip(recv_infos, prev_stage_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + ) + + # We don't need to do a data copy here, since we can directly pass the activation tensor reference from + # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve + # as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph. + # TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does + # detach have any affect on that? + info.buffer = tensor.detach().requires_grad_(True) + + def get_local_bwd_output(self, mb_index): + """ + Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. + """ + assert self.has_backward, ( + "can't steal_bwd_input if this stage doesn't have backward" + ) + assert not self.is_first, "can't get bwd output if this stage is first" + + self._check_chunk_id(mb_index) + return self.bwd_cache.pop(mb_index) + + def set_local_bwd_input( + self, next_stage_bwd_outputs: tuple[torch.Tensor | None, ...], mb_index: int + ) -> None: + """ + Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. + Does not detach or set '_requires_grad'. + """ + assert isinstance(next_stage_bwd_outputs, tuple), ( + f"Expected tuple, got {type(next_stage_bwd_outputs)}" + ) + + assert self.has_backward, ( + "can't set bwd input if this stage doesn't have backward" + ) + assert not self.is_last, "can't set bwd input if this stage is last" + recv_infos = self.grad_recv_info[mb_index] + for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + f"Expected a recv info, got {type(info)}" + ) + info.buffer = tensor + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. + """ + output_tuple, _ = self.fwd_cache[fwd_chunk_id] + + ops: list[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size(), + ) + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + if not self.has_backward or self.is_first: + return [] + + self._check_chunk_id(bwd_chunk_id) + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) + + ops: list[dist.P2POp] = [] + grads_input = self.bwd_cache.pop(bwd_chunk_id) + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size(), + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + + # Clear grad of input buffers in between schedule steps. This is because + # `torch.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for recv_tuple in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. + # See https://github.com/pytorch/pytorch/pull/92731 + a.buffer.grad = None + + def _map_tensor_from_recv_info( + self, + recv_infos: tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. + """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + return map_aggregate(cast(Argument, recv_infos), get_recv_tensor) + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. + """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def scale_grads(self, grad_scale_factor: int) -> None: + """Scale gradients model gradients by `grad_scale_factor`, which should be specified in coordination with the + loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor` + should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should + be set to 1. + + Should only be called once per pipeline schedule step, after all backwards passes have completed. + """ + + # PP scales only for its own contribution (microbatches), but relies on DP to scale further + # for DP degree. + if grad_scale_factor != 1: + for p in self.submod.parameters(): + if p.grad is not None: + p.grad.div_(grad_scale_factor) + + def backward_maybe_with_nosync( + self, + backward_type, + bwd_kwargs: dict, + last_backward: bool = False, + ) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None]: + """ + Whether using PP with FSDP, DDP, or replicate there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + + def perform_backward( + backward_type, + ) -> Callable[ + [], + tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]] | None], + ]: + if backward_type == "full": + return lambda: ( + stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ), + None, + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: ( + stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ), + None, + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + # If submod is wrapped by DDP + if isinstance(self.submod, DistributedDataParallel): + if last_backward: + # Last chunk, prepare for gradient reduction + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + result = perform_backward(backward_type)() + else: + with self.submod.no_sync(): # type: ignore[operator] + result = perform_backward(backward_type)() + + # If submod is a FSDP or replicate module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + + else: + # Non-DP submodule, regular backward + result = perform_backward(backward_type)() + + grads, param_groups = result + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + save_forward_output: bool = True, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. + As of Sept 2024: + - `args` applies to the first stage only, other stages receives args + through activation transmission. + - `kwargs` can be passed to all stages via respective `step` calls. + """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + + composite_kwargs = kwargs or {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + # Output chunks is only used for the last stage since we only merge the output of the last stage + if self.is_last and save_forward_output: + self.output_chunks.append(output) + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + + # We return the original user-provided output, not normalized to tuple. + # See [Note: pipeline model output type] + return output + + def backward_one_chunk( + self, + bwd_chunk_id: int, + loss=None, + full_backward: bool = True, + last_backward=False, + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + + If full_backward is True (the default), the full backward pass including weight and input gradients will be run, + and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id. + + If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time, + and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward. + + last_backward is controlled by the schedule and signals synchronization of gradients across DP groups + after the last backward. + """ + # skip backward computation if backward is not enabled + if not self.has_backward: + return + + self._check_chunk_id(bwd_chunk_id) + + ( + stage_output, + input_values, + ) = self.fwd_cache.pop(bwd_chunk_id) + + # Compute backward + if self.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + bwd_kwargs = { + "stage_output": loss, + "output_grads": None, + "input_values": input_values, + } + else: + # Otherwise, receive gradients from next stage + grads_output = self._retrieve_recv_grads(bwd_chunk_id) + # If an input to the pipeline requires gradient, + # `torch.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": grads_output, + "input_values": input_values, + } + + grads_input: tuple[torch.Tensor | None, ...] = () + + # Custom backward function + if self.dw_builder: + # TODO: We may want to change our semantics so we are allowed to ignore + # the 'dw_builder' and call full_backward directly when it is a full_backward op. + grads_input, _ = self.backward_maybe_with_nosync( + "full", + bwd_kwargs, + last_backward=last_backward, + ) + if full_backward: + self.dw_builder()() + else: + self.dw_runner[bwd_chunk_id] = self.dw_builder() + else: + if full_backward: + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + else: + param_groups: list[dict[str, Any]] | None = None + # Skip the backward for the first stage since we will perform the weight update with + # autograd.backward in backward_weight_one_chunk + if not self.is_first: + if isinstance(bwd_kwargs["stage_output"], torch.Tensor): + bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) + + # perform the partial backwards for the inputs with a custom backward function + # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs, last_backward=last_backward + ) + + # TODO: we dont need to save this, add to dw_runner? + self.backward_state[bwd_chunk_id] = ( + bwd_kwargs["input_values"], + param_groups, + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + ) + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None + + self.bwd_cache[bwd_chunk_id] = grads_input + + if self.is_last and not self.is_first: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + if not t._is_view(): # views are not detachable in-place + t.detach_() + + logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) + + def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): + # skip backward computation if backward is not enabled + if not self.has_backward: + return + + assert bwd_chunk_id in self.dw_runner, ( + f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}" + " without first calling `backward_one_chunk(full_backward=False)`" + ) + + if self.dw_builder is not None: + self.dw_runner.pop(bwd_chunk_id)() + else: + ( + input_values, + param_groups, + stage_output, + output_grads, + ) = self.backward_state.pop(bwd_chunk_id) + + if self.stage_index != 0: + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + } + self.backward_maybe_with_nosync( + "weight", bwd_kwargs, last_backward=last_backward + ) + else: + # TODO: figure out a better way to do this: + # if inputs does not require gradient, + # then the parameter group will not be fully captured during stage_backward_input + # in this case, we need call grad directly on the parameters + # To solve: make input fn do the intersect compute and then finish it off during W + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + } + self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + + def _validate_fwd_input(self, args, kwargs): + """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" + + if self.is_first: + # TODO why is there a separate recv_info for each pipeline chunk? + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] + else: + # We don't check inputs for non-0 stages assuming they don't accept + # user inputs in canonical pipeline scenarios + return + + if len(kwargs): + # TODO- need a mapping of kwarg to position in self.args_recv_info + # Without it, we are not 100% sure how to match the args and + # expected_args. + return + + # TODO- need a mapping of kwarg to position in self.args_recv_info + # maybe it's impossible to tell whether the len mismatches because + # (a) the user passed an extra arg or missed an arg + # (b) the user did not pass a kwarg, which has a default value baked into expected_args + expected_tensors_meta = [ + e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer + for e in expected_args + ] + validate_tensors_metadata( + f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args + ) + + def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): + """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. + Most likely, this could be cause either by incorrect user specification of output shapes, or because + shape inference was done on the original model but then at runtime the model is wrapped with something like + mixed precision which changes output dtype. + """ + expected_tensors_meta = self.get_outputs_meta() + validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs + ) + + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + """ + Get the operations to initialize the p2p communicators between previous and next stages. + This is done so by creating a dummy tensor and sending it to the next stage and receiving + from the previous stage. + """ + ops: list[dist.P2POp] = [] + next_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index + 1) + prev_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index - 1) + + recv_tensor = torch.zeros(1, device=self.device, dtype=torch.float32) + send_tensor = torch.tensor( + self.stage_index, device=self.device, dtype=torch.float32 + ) + # forward + if not self.is_first: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + # backward + if not self.is_first: + ops.append( + dist.P2POp( + dist.isend, + send_tensor, + group_peer=prev_stage_peer_rank, + group=self.group, + ) + ) + if not self.is_last: + ops.append( + dist.P2POp( + dist.irecv, + recv_tensor, + group_peer=next_stage_peer_rank, + group=self.group, + ) + ) + + return ops + + def perform_reduce_grad(self, grad_scale_factor: int): + """ + Called as a part of schedule IR. + REDUCE_GRAD action is scheduled after all microbatches W, B actions. + + Currently contains "post_backward" functionality for FSDP. + We can try to extract post_backward in a separate IR action in future. + """ + # Manually call post backward for FSDP + if isinstance(self.submod, FSDPModule): + fsdp_module = self.submod + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + + if isinstance(fsdp_module, ReplicateModule): + distributed_state = replicate.state(fsdp_module) # type: ignore[arg-type] + else: + distributed_state = fully_shard.state(fsdp_module) # type: ignore[attr-defined] + + for state in distributed_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + # it would be much better if pipelining backward invoked .backward so autograd hooks + # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being, + # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream. + distributed_state._root_post_backward_final_callback() + # Call gradient scaling at the end of the backward pass + # NOTE: this must happen after FSDP post_backward is FSDP is enabled + if grad_scale_factor != 1: + self.scale_grads(grad_scale_factor) + + +class _PipelineStage(_PipelineStageBase): + def __init__( + self, + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: dist.ProcessGroup | None = None, + ): + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + """ + _PipelineStageBase.__init__( + self, + stage_module, + stage_index, + pipe_info.num_stages, + device, + group, + ) + self.pipe_info = pipe_info + + # Find stage nodes in graph + submod_nodes = [ + node for node in pipe_info.graph.nodes if node.op == "call_module" + ] + if len(submod_nodes) != self.num_stages: + raise AssertionError( + f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" + ) + + # Find my stage node in graph + self.node = submod_nodes[self.stage_index] + self.name = self.node.name + logger.info( + "[%s] Creating PipelineStage %s for %s", + self.group_rank, + stage_index, + self.name, + ) + + # Create mapping from stage name to stage index + self.submod_to_stage_index: dict[str, int] = {} + for i, node in enumerate(submod_nodes): + self.submod_to_stage_index.setdefault(node.name, i) + + # Cast submodule to device + self._move_submod_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug("%s Found meta parameters!", self.log_prefix) + else: + self.submod.to(self.device) + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + """ + Create send/recv infrastructures for activations (during forward) + """ + # TODO(whc) + # this method should be deleted once lazy buffer allocation is implemented + # for now, it ignores args/kwargs because it should not need to do shape inference + for chunk in range(num_microbatches): + self.args_recv_info[chunk] = self._create_act_recv_info() + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + return tuple() + + def get_stage_index_of_submod( + self, + submod_name: str, + ): + """ + Given a submodule name, return the stage index of the submodule. + """ + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") + + return self.submod_to_stage_index[submod_name] + + def _create_act_recv_info( + self, + ): + """ + Create a tuple of `_RecvInfo` for inputs to the stage. + """ + + def create_recv_tensor(placeholder, arg_node): + """ + Create a receive buffer for a placeholder. + """ + example_value = placeholder.meta["val"] + if arg_node.op == "placeholder": + # This is a root level placeholder, thus an input argument to the entire model. + # We are likely at stage 0, hence no need to create a receive buffer. + return _RootArgPlaceholder(example_value) + + # Figure out the source stage of this input + while arg_node.target is operator.getitem: + # If the input is a getitem, we need to go deeper + arg_node = arg_node.args[0] + + assert arg_node.op == "call_module", ( + f"Expecting call_module, got {arg_node.op}" + ) + src_stage = self.get_stage_index_of_submod(arg_node.name) + + # Create a receive buffer for this placeholder + logger.debug( + "%s Creating recv buffer for input '%s' : %s, %s", + self.log_prefix, + placeholder.name, + example_value.shape, + example_value.dtype, + ) + buffer = _make_tensor_from_meta(example_value, self.device) + # In case there is backward pass, set requires_grad for receive buffers + # before first forward + if self.has_backward: + buffer.requires_grad_(True) + + return _RecvInfo( + arg_node.name, + src_stage, + buffer, + ) + + args_recv_info: list[InputInfo] = [] + # Filter out placeholder nodes from `self.submod` (a GraphModule) + placeholders = filter( # type: ignore[var-annotated] + lambda node: node.op == "placeholder", # type: ignore[arg-type] + self.submod.graph.nodes, # type: ignore[arg-type,union-attr] + ) + # `placeholders` are nodes internal to submod. + # `self.node.args` are dependency nodes in the outer graph. + # The two are 1:1. + for placeholder, arg_node in zip(placeholders, self.node.args): + # Create a receive buffer for this placeholder + recv_info = create_recv_tensor(placeholder, arg_node) + args_recv_info.append(recv_info) + + logger.debug( + "%s Activation recv / args info: %s", self.log_prefix, args_recv_info + ) + # `args` is a Tuple, hence we will return a Tuple[InputInfo] + return tuple(args_recv_info) + + def find_dst_rank( + self, + user: fx.Node, + ) -> int | None: + """ + Find the destination rank of a `user` node. + If the `user` is not a submod, `None` may be returned. + """ + if user.op == "call_module": + # User is a stage (`call_module`) + return self.get_stage_index_of_submod(user.name) + else: + # - If user.op == "output": + # No need to send back to rank 0 + # - If user.target is stage_backward: + # No need to send assuming submod output is stored locally or + # should be re-calculated in case of activation checkpointing + return None + + def _create_act_send_info(self): + """ + Create a dict of send info for activations. + The dict is of the form: + { + output_index: [dst_rank_0, dst_rank_1, ...], + ... + } + where the list of `dst_rank`s covers the case where an output value may + be consumed by multiple stages. + """ + # Output index: List of receiver ranks + act_send_info: dict[int, list] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = act_send_info.setdefault(out_idx, []) + for gi_user in user.users: + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) + + output_node = self._get_output_node() + output_vals: tuple[torch.Tensor] = tuple( + v.meta["val"] for v in flatten_args(output_node.args) + ) + self._configure_outputs_meta(output_vals) + + logger.debug("%s Send info: %s", self.log_prefix, act_send_info) + return act_send_info + + def _get_output_node(self): + output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] # type: ignore[union-attr] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + return output_node + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + """ + Create a tuple of `_RecvInfo` for gradients. + """ + # Dict[output_index, _RecvInfo] + grad_recv_info: dict[int, _RecvInfo] = {} + output_node = self._get_output_node() + + # The output node may take multiple args, meaning the submod having multiple output values. + output_vals = flatten_args(output_node.args) + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + output = output_vals[out_idx] + example_value = output.meta["val"] + logger.debug( + f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 + f": {example_value.shape}, {example_value.dtype}" + ) + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1, "Backward of skip connections not supported yet" + grad_src = dst_list[0] + grad_recv_info[out_idx] = _RecvInfo( + f"{grad_src}", # noqa: G004 + grad_src, + _make_tensor_from_meta(example_value, self.device), + ) + + # Convert to tuple for convenience in get_ops and retrieve tensor + grad_recv_info_tuple = tuple(grad_recv_info.values()) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) + return grad_recv_info_tuple + + +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: dist.ProcessGroup | None = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + +class PipelineStage(_PipelineStageBase): + """ + A class representing a pipeline stage in a pipeline parallelism setup. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. + + Args: + submodule (nn.Module): The PyTorch module wrapped by this stage. + stage_index (int): The ID of this stage. + num_stages (int): The total number of stages. + device (torch.device): The device where this stage is located. + input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. + output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. + group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder will build a new dw_runner function + that will the W action (input weights) for F, I, W (Fwd, Input, Weight) zero bubble schedules. + """ + + def __init__( + self, + submodule: nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + group: dist.ProcessGroup | None = None, + dw_builder: Callable[[], Callable[..., None]] | None = None, + ): + super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) + self.inputs: list[torch.Tensor] | None = None + self.inputs_meta: tuple[torch.Tensor, ...] | None = None + # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it + # might be breaking for existing users. + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. " + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) if isinstance(input_args, torch.Tensor) else input_args + ) + if output_args is None: + logger.warning( + "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " + "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " + "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " + "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " + ) + try: + with torch.no_grad(): + output_args = submodule(*self.inputs_meta) + output_args = tree_map_only( + torch.Tensor, lambda x: x.to("meta"), output_args + ) + except Exception as e: + raise RuntimeError( + "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" + ) from e + assert output_args is not None, ( + "If passing input_args, also pass output_args to override shape inference" + ) + self._configure_outputs_meta( + (output_args,) if isinstance(output_args, torch.Tensor) else output_args + ) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: list[torch.Tensor] = [] + + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) + else: + assert len(args) == 0, ( + "Can't supply input args for shape inference on non-first stage" + ) + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list( + objects, + src=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index - 1], + ), + group=self.group, + device=self.device, + use_batch=True, + ) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + args = tree_map_only( + torch.Tensor, lambda x: torch.zeros_like(x, device=self.device), args + ) + + # set attributes needed for forward + with torch.no_grad(): + outputs = self.submod(*args, **kwargs) + + # if single tensor, convert so it is always a list + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + + # communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + tree_map_only(torch.Tensor, lambda x: x.to("meta"), outputs) + ) + logger.debug( + "Shape inference: stage %s inputs %s, outputs %s", + self.stage_index, + self.inputs_meta, + outputs_meta, + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index + 1], + ), + group=self.group, + device=self.device, + use_batch=True, + ) + outputs_meta = tuple() + + return outputs_meta + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + ) -> tuple[Any, ...]: + # TODO move self.device to an argument from step API (from its input tensors)? + assert num_microbatches is not None, "TODO fix num_microbatches" + + outputs: tuple[Any, ...] = tuple() + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + + assert self.inputs_meta is not None + # Receive info during forward + # TODO: create args_recv_info lazily? (same needed for PipelineStage) + for chunk_id in range(num_microbatches): + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp, self.device), + ) + for inp in self.inputs_meta + ) + # In case there is backward pass, set requires_grad for receive buffers + if self.has_backward: + for r in recv_infos: + r.buffer.requires_grad_(True) + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + _RootArgPlaceholder(i) for i in self.inputs_meta + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: dict[int, list] = {} + + for idx in range(len(self.get_outputs_meta())): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + return outputs + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + grad_recv_info: tuple[_RecvInfo, ...] = () + if not self.is_last: + # Receiving gradients from multiple sources is not supported + # hence we only take the first destination + grad_recv_info = tuple( + _RecvInfo( + f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", + dst_list[0], + _make_tensor_from_meta(self.get_outputs_meta()[idx], self.device), + ) + for idx, dst_list in act_send_info.items() + ) + return grad_recv_info diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adf901d6b6e3e693f69464e5c64d58a857ae6014 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__init__.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +import logging +import os +import threading +import warnings +from collections.abc import Generator +from datetime import timedelta +from urllib.parse import urlparse + +import torch +import torch.distributed as dist + + +__all__ = ["is_available"] + + +logger = logging.getLogger(__name__) + + +_init_counter = 0 +_init_counter_lock = threading.Lock() + + +def is_available() -> bool: + return hasattr(torch._C, "_rpc_init") + + +if is_available() and not torch._C._rpc_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc") + + +if is_available(): + _is_tensorpipe_available = hasattr( + torch._C._distributed_rpc, "_TensorPipeRpcBackendOptionsBase" + ) + + import numbers + + import torch.distributed.autograd as dist_autograd + from torch._C._distributed_c10d import Store + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _disable_jit_rref_pickle, + _disable_server_process_global_profiler, + _enable_jit_rref_pickle, + _enable_server_process_global_profiler, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, + _set_rpc_timeout, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + WorkerInfo, + ) + + if _is_tensorpipe_available: + from torch._C._distributed_rpc import ( # noqa: F401 + _DEFAULT_NUM_WORKER_THREADS, + _TensorPipeRpcBackendOptionsBase, + TensorPipeAgent, + ) + + from . import api, backend_registry, functions + from .api import * # noqa: F401,F403 + from .backend_registry import BackendType + from .options import TensorPipeRpcBackendOptions # noqa: F401 + from .server_process_global_profiler import _server_process_global_profile + + rendezvous_iterator: Generator[tuple[Store, int, int], None, None] + + __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"] + __all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605 + + def init_rpc( + name, + backend=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + r""" + Initializes RPC primitives such as the local RPC agent + and distributed autograd, which immediately makes the current + process ready to send and receive RPCs. + + Args: + name (str): a globally unique name of this node. (e.g., + ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) + Name can only contain number, alphabet, underscore, colon, + and/or dash, and must be shorter than 128 characters. + backend (BackendType, optional): The type of RPC backend + implementation. Supported values is + ``BackendType.TENSORPIPE`` (the default). + See :ref:`rpc-backends` for more information. + rank (int): a globally unique id/rank of this node. + world_size (int): The number of workers in the group. + rpc_backend_options (RpcBackendOptions, optional): The options + passed to the RpcAgent constructor. It must be an agent-specific + subclass of :class:`~torch.distributed.rpc.RpcBackendOptions` + and contains agent-specific initialization configurations. By + default, for all agents, it sets the default timeout to 60 + seconds and performs the rendezvous with an underlying process + group initialized using ``init_method = "env://"``, + meaning that environment variables ``MASTER_ADDR`` and + ``MASTER_PORT`` need to be set properly. See + :ref:`rpc-backends` for more information and find which options + are available. + """ + torch._C._log_api_usage_once("torch.distributed.init_rpc") + if backend is not None and not isinstance( + backend, backend_registry.BackendType + ): + raise TypeError("Argument backend must be a member of BackendType") + + if rpc_backend_options is not None and not isinstance( + rpc_backend_options, RpcBackendOptions + ): + raise TypeError( + "Argument rpc_backend_options must be an instance of RpcBackendOptions" + ) + + # Try to detect the backend from the options + if backend is None and rpc_backend_options is not None: + for candidate_backend in BackendType: + if isinstance( + rpc_backend_options, + type( + backend_registry.construct_rpc_backend_options( + candidate_backend + ) + ), + ): + backend = candidate_backend + break + else: + raise TypeError( + f"Could not infer backend for options {rpc_backend_options}" + ) + # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865) + if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined] + logger.warning( + "RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined] + "corresponding to %(backend)s, hence that backend will be used " + "instead of the default BackendType.TENSORPIPE. To silence this " + "warning pass `backend=%(backend)s` explicitly.", + {"backend": backend}, + ) + + if backend is None: + backend = BackendType.TENSORPIPE # type: ignore[attr-defined] + + if rpc_backend_options is None: + # default construct a set of RPC backend options. + rpc_backend_options = backend_registry.construct_rpc_backend_options( + backend + ) + + # Create store, performs rendezvous for static RPC group. + if not world_size: + # If world_size is not set in construction and also not set in environment variables + # The store will be created for the dynamic group setting + store = dist._create_store_from_options(rpc_backend_options, rank) + else: + # This rendezvous state sometimes is destroyed before all processes + # finishing handshaking. To avoid that issue, we make it global to + # keep it alive. + global rendezvous_iterator + rendezvous_iterator = dist.rendezvous( + rpc_backend_options.init_method, rank=rank, world_size=world_size + ) + store, _, _ = next(rendezvous_iterator) + # Use same timeout as RPC. + store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout)) + + # Use a PrefixStore to distinguish multiple invocations. + with _init_counter_lock: + global _init_counter + store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store) + _init_counter += 1 + + # Initialize autograd before RPC since _init_rpc_backend guarantees all + # processes sync via the store. If we initialize autograd after RPC, + # there could be a race where some nodes might have initialized autograd + # and others might not have. As a result, a node calling + # torch.distributed.autograd.backward() would run into errors since + # other nodes might not have been initialized. + dist_autograd._init(rank) + + _set_profiler_node_id(rank) + # Initialize RPC. + _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options) + + def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options): + type_mapping = { + backend: backend_registry.BackendType, + store: dist.Store, + name: str, + rank: numbers.Integral, + # world_size can be None for a dynamic group + world_size: (numbers.Integral, type(None)), + rpc_backend_options: RpcBackendOptions, + } + for arg, arg_type in type_mapping.items(): + if not isinstance(arg, arg_type): # type: ignore[arg-type] + raise RuntimeError( + f"Argument {arg} must be of type {arg_type} but got type {type(arg)}" + ) + + def _init_rpc_backend( + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + store=None, + name=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) + + if _is_current_rpc_agent_set(): + raise RuntimeError("RPC is already initialized") + + # Initialize RPC. + rpc_agent = backend_registry.init_backend( + backend, + store=store, + name=name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + api._init_rpc_states(rpc_agent) + + @api._require_initialized + def _get_debug_info(): + info = _rref_context_get_debug_info() + info.update(api._get_current_rpc_agent().get_debug_info()) + info.update(dist_autograd._get_debug_info()) + return info diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ae6a4429bdb3b1b802d472973e771324a4f82dc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..740a78c3d6948e28ee60350a6ff444cdacc3cd60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d289ef13dea7067573856848460cfb4cde38e93 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a366772df5fd08e94ce7d876c1b020a6dc228e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe79413ec86f1105de9e1787dd762a9e95ec392a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223799efb7319314149a103541ba3f4b10bfa803 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444cfac4bf6bbd7d40bebb2ec58dea1df945c374 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2643daf27e64465cb76d92d53007c855f4852cf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/options.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a31af343559a6b60528814e06e3e626ec614177 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19d484f4fe39f8baaf6f8e1f1c43ecb46632d02a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0abd737becafbae33b0b63799c1eb43c913e1998 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__init__.py @@ -0,0 +1,18 @@ +import torch + + +def is_available() -> bool: + return hasattr(torch._C, "_faulty_agent_init") + + +if is_available() and not torch._C._faulty_agent_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc._testing") + +if is_available(): + # Registers FAULTY_TENSORPIPE RPC backend. + from torch._C._distributed_rpc_testing import ( + FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, + ) + + from . import faulty_agent_backend_registry diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c52002a71bf9bc4e38822a903ed43093b8f34c9f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..556a10046240b8740832c8181cb4fb0232af6dc2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d04882e16e79a94f74ddc1350e94f547ef625611 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +import torch.distributed as dist +import torch.distributed.rpc as rpc + + +def _faulty_tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads, + messages_to_fail, + messages_to_delay, + num_fail_sends, + **kwargs, +): + from . import FaultyTensorPipeRpcBackendOptions + + return FaultyTensorPipeRpcBackendOptions( + num_worker_threads=num_worker_threads, + rpc_timeout=rpc_timeout, + init_method=init_method, + messages_to_fail=messages_to_fail, + messages_to_delay=messages_to_delay, + num_fail_sends=num_fail_sends, + ) + + +def _faulty_tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from torch.distributed.rpc import api + + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + agent = FaultyTensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, # reverse_device_map + [], # devices + ) + api._init_rpc_states(agent) + + return agent + + +rpc.backend_registry.register_backend( + "FAULTY_TENSORPIPE", + _faulty_tensorpipe_construct_rpc_backend_options_handler, + _faulty_tensorpipe_init_backend_handler, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a0021ff1e43d8653df457cb99e7ea3637a508851 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/_utils.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager +from typing import cast + + +logger = logging.getLogger(__name__) + + +@contextmanager +def _group_membership_management(store, name, is_join): + token_key = "RpcGroupManagementToken" + join_or_leave = "join" if is_join else "leave" + my_token = f"Token_for_{name}_{join_or_leave}" + while True: + # Retrieve token from store to signal start of rank join/leave critical section + returned = store.compare_set(token_key, "", my_token).decode() + if returned == my_token: + # Yield to the function this context manager wraps + yield + # Finished, now exit and release token + # Update from store to signal end of rank join/leave critical section + store.set(token_key, "") + # Other will wait for this token to be set before they execute + store.set(my_token, "Done") + break + else: + # Store will wait for the token to be released + try: + store.wait([returned]) + except RuntimeError: + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) + raise + + +def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): + from . import api, TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) + return ret diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/api.py new file mode 100644 index 0000000000000000000000000000000000000000..845ce0b7faf6c4cb1390c4d7089f745a1861f335 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/api.py @@ -0,0 +1,965 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +import collections +import contextlib +import functools +import inspect +import logging +import threading +from typing import Any, Generic, TYPE_CHECKING, TypeVar + +import torch +from torch._C._distributed_rpc import ( + _cleanup_python_rpc_handler, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + WorkerInfo, +) +from torch.futures import Future + +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT +from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, + PythonUDF, + RPCExecMode, +) + + +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + + +logger = logging.getLogger(__name__) + +# NB: Ignoring RRef leaks during shutdown. Without this, applications have to +# make sure there is no references to any RRef in the application code and +# Python GC has done its job to delete those RRefs. This is could result in bad +# debugging experiences especially when for large applications. Therefore, by +# default, we are going to ignore RRef leaks during shutdown. This is usually +# fine as shutdown means applications have done training and no longer care +# about states. +# +# To enable RRef leak checking, set this _ignore_rref_leak to False +_ignore_rref_leak = True +_default_pickler = _internal_rpc_pickler + + +@contextlib.contextmanager +def _use_rpc_pickler(rpc_pickler): + r""" + rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler + """ + global _default_pickler + _default_pickler = rpc_pickler + try: + yield + finally: + _default_pickler = _internal_rpc_pickler + + +def _require_initialized(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not _is_current_rpc_agent_set(): + raise RuntimeError( + "RPC has not been initialized. Call " + "torch.distributed.rpc.init_rpc first." + ) + return func(*args, **kwargs) + + return wrapper + + +class AllGatherStates: + def __init__(self): + # Each `gathered_objects` is an empty dict at beginning. + # The leader worker is elected as the first worker in a sorted worker + # name list. Whenever there is a worker entering `_all_gather()`, it + # runs `_gather_to_leader()` on the leader to add its own name and + # data obj to this dict. The leader also adds itself's name to the dict + # on calling `_all_gather()`. + # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader + # will broadcast the gathered dict to all follower workers and set their + # `gathered_objects` field and the `proceed_signal` field. + self.gathered_objects = {} + # All workers wait on this signal until it receives all gathered + # objects. + self.proceed_signal = threading.Event() + + +# States used by `def _all_gather()`. +# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. +_ALL_WORKER_NAMES: set[Any] = set() +_all_gather_dict_lock = threading.RLock() +_all_gather_sequence_id: dict[str, int] = {} +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) + + +def _init_rpc_states(agent): + worker_infos = agent.get_worker_infos() + global _ALL_WORKER_NAMES + _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} + + # NB: backend implementation might have already set the rpc_agent. + if not _is_current_rpc_agent_set(): + _set_and_start_rpc_agent(agent) + + +def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): + with _all_gather_dict_lock: + if not worker_names: + worker_names = _ALL_WORKER_NAMES + assert worker_name in worker_names, ( + f"{worker_name} is not expected by leader." + ) + states = _all_gather_sequence_id_to_states[sequence_id] + assert worker_name not in states.gathered_objects, ( + f"{worker_name} reported intent sequence id {sequence_id} twice. " + ) + states.gathered_objects[worker_name] = obj + if worker_names == set(states.gathered_objects.keys()): + states.proceed_signal.set() + + +def _broadcast_to_followers(sequence_id, objects_map): + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + assert not states.proceed_signal.is_set(), ( + f"Termination signal sequence id {sequence_id} got set twice." + ) + states.gathered_objects = objects_map + states.proceed_signal.set() + + +_thread_local_var = threading.local() + + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + torch.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list + + +@_require_initialized +def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + This is similar to torch.distributed.all_gather(), but is using RPC. It + picks the worker with the smallest name (alphabetic order) as the leader. + Then all followers send their data ``obj`` to the leader. After the leader + has received all, it will broadcast the results back to all followers. This + function blocks until all workers have received the gathered results. + """ + if not worker_names: + assert _ALL_WORKER_NAMES is not None, ( + "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." + ) + worker_names = _ALL_WORKER_NAMES + leader_name = min(worker_names) + + self_name = _get_current_rpc_agent().get_worker_info().name + + with _all_gather_dict_lock: + concat_names = "".join(sorted(worker_names)) + sequence_num = _all_gather_sequence_id.get(concat_names, 0) + _all_gather_sequence_id[concat_names] = sequence_num + 1 + sequence_id = concat_names + str(sequence_num) + + is_leader = leader_name == self_name + + if timeout == UNSET_RPC_TIMEOUT: + # Timeout is specified by agent for RPC calls + rpc_timeout = get_rpc_timeout() + # No timeout for signal + signal_timeout = None + elif timeout == DEFAULT_SHUTDOWN_TIMEOUT: + # No timeout for RPC + rpc_timeout = timeout + # No timeout for signal + signal_timeout = None + else: + # Signal and RPC timeout use the same timeout + signal_timeout = rpc_timeout = timeout + + # Phase 1: Followers send it's object to the leader + if is_leader: + _gather_to_leader(sequence_id, self_name, obj, worker_names) + else: + rpc_sync( + leader_name, + _gather_to_leader, + args=(sequence_id, self_name, obj, worker_names), + timeout=rpc_timeout, + ) + + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + # Timeout is either set by function parameter or None (which is indefinite) + states.proceed_signal.wait(timeout=signal_timeout) + + # Phase 2: Leader broadcast gathered results to all followers + # Leader's signal is the first to be unblocked, after receiving all + # followers' data objects. + if is_leader: + worker_name_to_response_future_dict = {} + for follower_name in worker_names - {leader_name}: + fut = rpc_async( + follower_name, + _broadcast_to_followers, + args=(sequence_id, states.gathered_objects), + timeout=rpc_timeout, + ) + worker_name_to_response_future_dict[follower_name] = fut + + errors = [] + for follower_name, fut in worker_name_to_response_future_dict.items(): + try: + fut.wait() + except RuntimeError as ex: + errors.append((follower_name, ex)) + + if errors: + raise RuntimeError( + f"Followers {[e[0] for e in errors]} timed out in _all_gather " + f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}" + ) + + # Clean up for the states using the sequence_id + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states.pop(sequence_id) + return states.gathered_objects + + +@_require_initialized +def _barrier(worker_names): + r""" + Synchronizes local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names (List[str]): The set of workers to synchronize. + + """ + try: + _all_gather(None, set(worker_names)) + except RuntimeError: + logger.exception("Failed to complete barrier") + + +@_require_initialized +def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Block until all local and remote RPC processes reach this method and wait + for all outstanding work to complete. Every RPC process must call this + method before exit to perform a graceful shutdown. This should be used to + terminate the RPC framework, and there is no guarantee that the RPC + framework will work after this method returns. + """ + try: + _all_gather(None, timeout=timeout) + except RuntimeError as ex: + logger.exception("Failed to respond to 'Shutdown Proceed' in time") + raise ex + + +@_require_initialized +def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Perform a shutdown of the RPC agent, and then destroy the RPC agent. This + stops the local agent from accepting outstanding requests, and shuts + down the RPC framework by terminating all RPC threads. If ``graceful=True``, + this will block until all local and remote RPC processes reach this method + and wait for all outstanding work to complete. Otherwise, if + ``graceful=False``, this is a local shutdown, and it does not wait for other + RPC processes to reach this method. + + .. warning:: + For :class:`~torch.futures.Future` objects returned by + :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not + be called after ``shutdown()``. + + Args: + graceful (bool): Whether to do a graceful shutdown or not. If True, + this will 1) wait until there is no pending system + messages for ``UserRRefs`` and delete them; 2) block + until all local and remote RPC processes have reached + this method and wait for all outstanding work to + complete. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> # do some work + >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) + >>> # ready to shutdown + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + if graceful: + try: + agent = _get_current_rpc_agent() + from torch._C._distributed_rpc import TensorPipeAgent + + if not isinstance(agent, TensorPipeAgent) or agent.is_static_group: + _wait_all_workers(timeout) + _delete_all_user_and_unforked_owner_rrefs() + agent.join(shutdown=True, timeout=timeout) + else: + # This is a dynamic group so we need to grab the token for the operation + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + with _group_membership_management(agent.store, my_name, False): + all_worker_infos = agent.get_worker_infos() + for worker in all_worker_infos: + if worker.name != my_name: + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) + agent.join(shutdown=True, timeout=timeout) + finally: + # In case of errors, continue to complete the local shutdown. + _finalize_shutdown() + else: + _finalize_shutdown() + + +def _finalize_shutdown(): + try: + # This raises a `TORCH_CHECK()` exception on RRef leak detected. + _destroy_rref_context(_ignore_rref_leak) + finally: + _get_current_rpc_agent().shutdown() + # clean up python rpc handler in shutdown(), see comments in + # PythonRpcHandler::cleanup(), call it in python API because the + # cleanup() function has python dependency, it assumes python + # interpreter exists. + # No matter if RRef leak exception is raised, this clean-up code + # must run to avoid destruction segfault in Python 3.5. + # + # future.wait() should not be called after shutdown(). + # pythonRpcHandler is cleaned up in shutdown(), after + # shutdown(), python objects returned from rpc python call can not be + # resolved. + _cleanup_python_rpc_handler() + _reset_current_rpc_agent() + + +@_require_initialized +def get_worker_info(worker_name=None): + r""" + Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name. + Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an + expensive string on every invocation. + + Args: + worker_name (str): the string name of a worker. If ``None``, return the + the id of the current worker. (default ``None``) + + Returns: + :class:`~torch.distributed.rpc.WorkerInfo` instance for the given + ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the + current worker if ``worker_name`` is ``None``. + """ + if worker_name is not None: + return _get_current_rpc_agent().get_worker_info(worker_name) + else: + return _get_current_rpc_agent().get_worker_info() + + +def _to_worker_info(to): + if isinstance(to, WorkerInfo): + return to + elif isinstance(to, (str, int)): + return get_worker_info(to) + else: + raise ValueError(f"Cannot get WorkerInfo from name {to}") + + +def _rref_typeof_on_owner(rref, blocking: bool = True): + rref_type = type(rref.local_value()) + if blocking: + return rref_type + else: + # Wrap result into a completed Future. This is so that if blocking=`False` + # is specified, we return a future regardless of if this call is on user + # or owner. + future = Future[type]() + future.set_result(rref_type) + return future + + +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) + if blocking: + return fut.wait() + else: + return fut + + +T = TypeVar("T") +# pyrefly: ignore [invalid-annotation] +GenericWithOneTypeVar = Generic[T] + + +if TYPE_CHECKING: + + class RRef(PyRRef[T], Generic[T]): + pass + +else: + try: + # Combine the implementation class and the type class. + class RRef(PyRRef, Generic[T]): + pass + + except TypeError: + # TypeError: metaclass conflict: the metaclass of a derived class + # must be a (non-strict) subclass of the metaclasses of all its bases + # Mypy doesn't understand __class__ (mypy bug #4177) + class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type] + pass + + # Combine the implementation class and the type class. + # Types for classes expecting a certain generic parameter (mypy bug #7791) + class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type] + pass + + +# Install docstrings from `PyRRef` to `RRef`. +# +# This is for the fact that pybind11 generates the parameter +# `self` as type `rpc.PyRRef`, so a `:inherited-members:` +# under `.. autoclass:: RRef` does not work. +# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. +# +def method_factory(method_name, docstring): + def method(self, *args, **kwargs): + return getattr(super(RRef, self), method_name)(*args, **kwargs) + + if method.__doc__: + method.__doc__ = docstring + return method + + +for method_name, method in inspect.getmembers(PyRRef): + # Ignore magic methods, except "__str__". + if method_name.startswith("_") and method_name != "__str__": + continue + + # Get pybind11 generated docstring. + # It's like, + """ + to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object + + Blocking call that copies the value of the RRef from the owner + to the local node and returns it. If the current node is the + owner, returns a reference to the local value. + """ + docstring = getattr(method, "__doc__", None) + assert docstring is not None, "RRef user-facing methods should all have docstrings." + + # Do surgery on pybind11 generated docstrings. + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) + + # Attach user-facing RRef method with modified docstring. + new_method = method_factory(method_name, docstring) + setattr(RRef, method_name, new_method) + + +@_require_initialized +def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a remote call to run ``func`` on worker ``to`` and return an + :class:`~torch.distributed.rpc.RRef` to the result value immediately. + Worker ``to`` will be the owner of the returned + :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is + a user. The owner manages the global reference count of its + :class:`~torch.distributed.rpc.RRef`, and the owner + :class:`~torch.distributed.rpc.RRef` is only destructed when globally there + are no living references to it. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + + timeout (float, optional): timeout in seconds for this remote call. If the + creation of this + :class:`~torch.distributed.rpc.RRef` on worker + ``to`` is not successfully processed on this + worker within this timeout, then the next time + there is an attempt to use the RRef (such as + ``to_here()``), a timeout will be raised + indicating this failure. A value of 0 indicates + an infinite timeout, i.e. a timeout error will + never be raised. If not provided, the default + value set during initialization or with + ``_set_rpc_timeout`` is used. + + Returns: + A user :class:`~torch.distributed.rpc.RRef` instance to the result + value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here` + to retrieve the result value locally. + + .. warning :: + The ``remote`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned RRef is + confirmed by the owner, which can be checked using the + :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API. + + .. warning :: + Errors such as timeouts for the ``remote`` API are handled on a + best-effort basis. This means that when remote calls initiated by + ``remote`` fail, such as with a timeout error, we take a best-effort + approach to error handling. This means that errors are handled and set + on the resulting RRef on an asynchronous basis. If the RRef has not been + used by the application before this handling (such as ``to_here`` or + fork call), then future uses of the ``RRef`` will appropriately raise + errors. However, it is possible that the user application will use the + ``RRef`` before the errors are handled. In this case, errors may not be + raised as they have not yet been handled. + + Example:: + + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> x = rref1.to_here() + rref2.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rref.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_remote") + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + rref = _invoke_remote_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + timeout, + is_async_exec, + *args, + **kwargs, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + rref = _invoke_remote_python_udf( + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec + ) + # attach profiling information + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + fut = rf._call_end_callbacks_on_future(rref._get_future()) + rref._set_profiling_future(fut) + + return rref + + +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): + if not callable(func): + raise TypeError("function should be callable.") + + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + # pyrefly: ignore [missing-attribute] + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + fut = _invoke_rpc_builtin( + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + fut = _invoke_rpc_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + args, + kwargs, + rpc_timeout, + is_async_exec, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + fut = _invoke_rpc_python_udf( + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec + ) + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + # Schedule profiling callbacks to run when the future completes. + # This returns a future that is completed when the original future + # completes and the profiling callbacks have been completed as well, + # to guarantee that fut.wait() completes the profiling. This new + # future will contain the same value as the original future. + fut = rf._call_end_callbacks_on_future(fut) + return fut + + +@_require_initialized +def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + Make a blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + Returns: + Returns the result of running ``func`` with ``args`` and ``kwargs``. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + """ + torch._C._log_api_usage_once("torch.distributed.rpc_sync") + fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout) + return fut.wait() + + +@_require_initialized +def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. This method will immediately return a + :class:`~torch.futures.Future` that can be awaited on. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + + Returns: + Returns a :class:`~torch.futures.Future` object that can be waited + on. When completed, the return value of ``func`` on ``args`` and + ``kwargs`` can be retrieved from the :class:`~torch.futures.Future` + object. + + .. warning :: + Using GPU tensors as arguments or return values of ``func`` is not + supported since we don't support sending GPU tensors over the wire. You + need to explicitly copy GPU tensors to CPU before using them as + arguments or return values of ``func``. + + .. warning :: + The ``rpc_async`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned + :class:`~torch.futures.Future` completes. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3)) + >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2)) + >>> result = fut1.wait() + fut2.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> ret = fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_async") + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut + + +def _get_should_profile(): + # Legacy profiler should be enabled. RPC profiling is not supported with + # Kineto profiler. + ActiveProfilerType = torch._C._profiler.ActiveProfilerType + return ( + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + ) + + +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): + ctx_manager = contextlib.nullcontext() + + if should_profile: + # Create appropriate string representation based on type of func + # (builtin, script, python) + if qualified_name is None: + func_name = ( + torch._jit_internal._qualified_name(func) + if isinstance(func, torch.jit.ScriptFunction) + else func.__qualname__ + ) + else: + func_name = qualified_name + # Build RPC profiling key. + rpc_profiling_key = _build_rpc_profiling_key( + rpc_type, + func_name, + get_worker_info().name, + dst_worker_info.name, + ) + RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] + + return ctx_manager diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3f30252bd825665280a9b4cf96613bd6a676d190 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/backend_registry.py @@ -0,0 +1,431 @@ +# mypy: allow-untyped-defs + + +import collections +import enum +from typing import cast + +import torch +import torch.distributed as dist + +from . import api, constants as rpc_constants +from ._utils import _group_membership_management, _update_group_membership + + +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] + +BackendValue = collections.namedtuple( + "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] +) + + +def _backend_type_repr(self): + return "BackendType." + self.name + + +_backend_type_doc = """ + An enum class of available backends. + + PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. + Additional ones can be registered using the + :func:`~torch.distributed.rpc.backend_registry.register_backend` function. +""" + +# Create an enum type, `BackendType`, with empty members. +# Can't handle Function Enum API (mypy bug #9079) +BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc] +# Unable to assign a function a method (mypy bug #2427) +BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + +if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + + +def backend_registered(backend_name): + """ + Checks if backend_name is registered as an RPC backend. + + Args: + backend_name (str): string to identify the RPC backend. + Returns: + True if the backend has been registered with ``register_backend``, else + False. + """ + return backend_name in BackendType.__members__ + + +def register_backend( + backend_name, construct_rpc_backend_options_handler, init_backend_handler +): + """Registers a new RPC backend. + + Args: + backend_name (str): backend string to identify the handler. + construct_rpc_backend_options_handler (function): + Handler that is invoked when + rpc_backend.construct_rpc_backend_options(**dict) is called. + init_backend_handler (function): Handler that is invoked when the + `_init_rpc_backend()` function is called with a backend. + This returns the agent. + """ + global BackendType + if backend_registered(backend_name): + raise RuntimeError(f"RPC backend {backend_name}: already registered") + # Create a new enum type, `BackendType`, with extended members. + existing_enum_dict = {member.name: member.value for member in BackendType} + extended_enum_dict = dict( + { + backend_name: BackendValue( + construct_rpc_backend_options_handler=construct_rpc_backend_options_handler, + init_backend_handler=init_backend_handler, + ) + }, + **existing_enum_dict, + ) + # Can't handle Function Enum API (mypy bug #9079) + BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] + # Unable to assign a function a method (mypy bug #2427) + BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + # pyrefly: ignore [unsupported-operation] + return BackendType[backend_name] + + +def construct_rpc_backend_options( + backend, + rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, + init_method=rpc_constants.DEFAULT_INIT_METHOD, + **kwargs, +): + return backend.value.construct_rpc_backend_options_handler( + rpc_timeout, init_method, **kwargs + ) + + +def init_backend(backend, *args, **kwargs): + return backend.value.init_backend_handler(*args, **kwargs) + + +def _init_process_group(store, rank, world_size): + # Initialize ProcessGroup. + process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT + + # We're using a bunch of private APIs here since `new_group` requires the + # default group to be initialized. + group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout) + + assert group is not None, "Failed to initialize default ProcessGroup." + + if (rank != -1) and (rank != group.rank()): + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") + if (world_size != -1) and (world_size != group.size()): + raise RuntimeError( + f"world_size argument {world_size} doesn't match pg size {group.size()}" + ) + return group + + +def _tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, + _transports=None, + _channels=None, + **kwargs, +): + from . import TensorPipeRpcBackendOptions + + return TensorPipeRpcBackendOptions( + rpc_timeout=rpc_timeout, + init_method=init_method, + num_worker_threads=num_worker_threads, + _transports=_transports, + _channels=_channels, + ) + + +def _tensorpipe_validate_devices(devices, device_count): + return all( + d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count) + for d in devices + ) + + +# detect if any worker has invalid device_map configurations, and return +# reverse device maps +def _tensorpipe_exchange_and_check_all_device_maps( + my_name, my_device_count, my_device_maps, my_devices, group +): + gathered: list[ + tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] + dist.all_gather_object( + gathered, (my_name, my_device_count, my_device_maps, my_devices), group + ) + all_names = [name for name, _, _, _ in gathered] + all_device_counts = {name: count for name, count, _, _ in gathered} + all_device_maps = {name: map_ for name, _, map_, _ in gathered} + all_devices = {name: devices for name, _, _, devices in gathered} + + _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices) + + # passed all checked, construct reverse mapping and get list of devices handled by this agent + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) + return reverse_device_maps, my_devices + + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): + for node in all_names: + devices = all_devices[node] + if len(set(devices)) != len(devices): + raise ValueError(f"Node {node} has duplicated devices\ndevices = {devices}") + if not _tensorpipe_validate_devices(devices, all_device_counts[node]): + raise ValueError( + f"Node {node} has devices with invalid indices\n" + f"devices = {devices}\n" + f"device count = {all_device_counts[node]}" + ) + + for source_node in all_names: + # For dynamic group (non-static) do not check the target node name since it may not have joined yet + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): + raise ValueError( + f"Node {source_node} has invalid target node names in its device maps\n" + f"device maps = {all_device_maps[source_node].keys()}\n" + f"node names = {all_names}" + ) + for target_node, map_ in all_device_maps[source_node].items(): + if len(set(map_.values())) != len(map_): + raise ValueError( + f"Node {source_node} has duplicated target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}" + ) + if all_devices[source_node]: + if not set(map_.keys()).issubset(all_devices[source_node]): + raise ValueError( + f"Node {source_node} has unexpected source devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[source_node]}" + ) + elif not _tensorpipe_validate_devices( + map_.keys(), all_device_counts[source_node] + ): + raise ValueError( + f"Node {source_node} has source devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[source_node]}" + ) + if all_devices.get(target_node, []): + if not set(map_.values()).issubset(all_devices[target_node]): + raise ValueError( + f"Node {source_node} has unexpected target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[target_node]}" + ) + elif target_node in all_device_counts and not _tensorpipe_validate_devices( + map_.values(), all_device_counts[target_node] + ): + raise ValueError( + f"Node {source_node} has target devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[target_node]}" + ) + + +def _create_device_list(my_devices, my_device_maps, reverse_device_maps): + if not my_devices: + devices_set: set[torch.device] = set() + for map_ in my_device_maps.values(): + devices_set.update(map_.keys()) + for map_ in reverse_device_maps.values(): + devices_set.update(map_.keys()) + devices_set.discard(torch.device("cpu")) + my_devices = list(devices_set) + my_devices = sorted(my_devices, key=lambda d: d.index) + return my_devices + + +def _create_reverse_mapping(my_name, all_names, all_device_maps): + reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {} + for node in all_names: + if my_name in all_device_maps[node]: + reverse_device_maps[node] = { + v: k for k, v in all_device_maps[node][my_name].items() + } + return reverse_device_maps + + +def _get_device_infos(): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + opts = agent._get_backend_options() + device_count = torch.cuda.device_count() + if torch.cuda.is_available() and opts.devices: + torch.cuda.init() + return device_count, opts.device_maps, opts.devices + + +def _set_devices_and_reverse_device_map(agent): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, agent) + # Group state is retrieved from local agent + # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + all_worker_infos = agent.get_worker_infos() + # One round to get device_maps of all workers and construct reverse device maps + all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, [] + for worker_info in all_worker_infos: + worker_name = worker_info.name + if worker_name != my_name: + # TODO: make async? + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) + else: + opts = agent._get_backend_options() + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) + all_device_counts[worker_name] = device_count + all_device_maps[worker_name] = device_map + all_devices[worker_name] = devices + all_names.append(worker_name) + + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + + # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps + for worker_name in all_names: + # Set device list for each worker + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + device_count = torch.cuda.device_count() + + is_static_group = bool(world_size) + # world_size is specified so this is a static group (ranks cannot join and leave) + if is_static_group: + # The agent's join method is required to behave like a barrier and perform + # collective operations, for which it relies on a process group, instead of + # re-implementing this on top of RPCs. + group = _init_process_group(store, rank, world_size) + + reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps( + name, + device_count, + rpc_backend_options.device_maps, + rpc_backend_options.devices, + group, + ) + + if torch.cuda.is_available() and devices: + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + torch.cuda.init() + + # TODO: add try-except and destroy _agent in all processes if any fails. + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + reverse_device_maps, + devices, + ) + + api._init_rpc_states(agent) + + # Run one dummy round of RPC to initialize channels/transports. Without + # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC + # on that process before rpc.shutdown(), as the agent initialization can + # take longer than 5s. + api._all_gather(None, timeout=rpc_backend_options.rpc_timeout) + # Need a barrier here to make sure no peers leave before the rank0 finishes + # _all_gather + group.barrier().wait() + + return agent + # initialization for dynamic rpc (ranks can join and leave) + else: + with _group_membership_management(store, name, True): + # Construct TPAgent with empty reverse_device_map and devices + # these properties will be updated after initialization + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, + [], + ) + api._init_rpc_states(agent) + + try: + # Notify all workers in group this rank has joined and set devices and reverse_device_map + # This is a synchronous operation that completes once all existing ranks are updated + _set_devices_and_reverse_device_map(agent) + except Exception: + api.shutdown() + raise + return agent + + +register_backend( + "TENSORPIPE", + _tensorpipe_construct_rpc_backend_options_handler, + _tensorpipe_init_backend_handler, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/constants.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f0eaf92b8aef56dc96700c1ddb42bfb988542650 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/constants.py @@ -0,0 +1,24 @@ +from datetime import timedelta + +from torch._C._distributed_rpc import ( + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _UNSET_RPC_TIMEOUT, +) + + +# For any RpcAgent. +DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC +DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD +DEFAULT_SHUTDOWN_TIMEOUT: float = 0 + +# For TensorPipeAgent. +DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS +# Ensure that we don't time out when there are long periods of time without +# any operations against the underlying ProcessGroup. +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) +# Value indicating that timeout is not set for RPC call, and the default should be used. +UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT + +__all__: list[str] = [] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e48ea8cc534ab87838965c947bbd0ed76d4d64c7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/functions.py @@ -0,0 +1,169 @@ +# mypy: allow-untyped-defs +import functools + + +def async_execution(fn): + r""" + A decorator for a function indicating that the return value of the function + is guaranteed to be a :class:`~torch.futures.Future` object and this + function can run asynchronously on the RPC callee. More specifically, the + callee extracts the :class:`~torch.futures.Future` returned by the wrapped + function and installs subsequent processing steps as a callback to that + :class:`~torch.futures.Future`. The installed callback will read the value + from the :class:`~torch.futures.Future` when completed and send the + value back as the RPC response. That also means the returned + :class:`~torch.futures.Future` only exists on the callee side and is never + sent through RPC. This decorator is useful when the wrapped function's + (``fn``) execution needs to pause and resume due to, e.g., containing + :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals. + + .. note:: To enable asynchronous execution, applications must pass the + function object returned by this decorator to RPC APIs. If RPC detected + attributes installed by this decorator, it knows that this function + returns a ``Future`` object and will handle that accordingly. + However, this does not mean this decorator has to be outmost one when + defining a function. For example, when combined with ``@staticmethod`` + or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the + inner decorator to allow the target function be recognized as a static + or class function. This target function can still execute asynchronously + because, when accessed, the static or class method preserves attributes + installed by ``@rpc.functions.async_execution``. + + + Example:: + The returned :class:`~torch.futures.Future` object can come from + :meth:`~torch.distributed.rpc.rpc_async`, + :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future` + constructor. The example below shows directly using the + :class:`~torch.futures.Future` returned by + :meth:`~torch.futures.Future.then`. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @rpc.functions.async_execution + >>> def async_add_chained(to, x, y, z): + >>> # This function runs on "worker1" and returns immediately when + >>> # the callback is installed through the `then(cb)` API. In the + >>> # mean time, the `rpc_async` to "worker2" can run concurrently. + >>> # When the return value of that `rpc_async` arrives at + >>> # "worker1", "worker1" will run the lambda function accordingly + >>> # and set the value for the previously returned `Future`, which + >>> # will then trigger RPC to send the result back to "worker0". + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> # xdoctest: +SKIP + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add_chained, + >>> args=("worker2", torch.ones(2), 1, 1) + >>> ) + >>> print(ret) # prints tensor([3., 3.]) + + When combined with TorchScript decorators, this decorator must be the + outmost one. + + >>> from torch import Tensor + >>> from torch.futures import Future + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @torch.jit.script + >>> def script_add(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> @rpc.functions.async_execution + >>> @torch.jit.script + >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + >>> return rpc.rpc_async(to, script_add, (x, y)) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add, + >>> args=("worker2", torch.ones(2), 1) + >>> ) + >>> print(ret) # prints tensor([2., 2.]) + + When combined with static or class method, this decorator must be the + inner one. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> class AsyncExecutionClass: + >>> + >>> @staticmethod + >>> @rpc.functions.async_execution + >>> def static_async_add(to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> @classmethod + >>> @rpc.functions.async_execution + >>> def class_async_add(cls, to, x, y, z): + >>> ret_fut = torch.futures.Future() + >>> rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: ret_fut.set_result(fut.wait() + z) + >>> ) + >>> return ret_fut + >>> + >>> @rpc.functions.async_execution + >>> def bound_async_add(self, to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.static_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.class_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + + This decorator also works with RRef helpers, i.e., . + :meth:`torch.distributed.rpc.RRef.rpc_sync`, + :meth:`torch.distributed.rpc.RRef.rpc_async`, and + :meth:`torch.distributed.rpc.RRef.remote`. + + >>> from torch.distributed import rpc + >>> + >>> # reuse the AsyncExecutionClass class above + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait() + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() + >>> print(ret) # prints tensor([4., 4.]) + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Can't declare and use attributes of function objects (mypy#2087) + wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] + return wrapper diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/internal.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/internal.py new file mode 100644 index 0000000000000000000000000000000000000000..faef8afddfc2caac25c8360c216509aed5acf8c1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/internal.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import collections +import copyreg +import io +import pickle +import sys +import threading +import traceback +from enum import Enum + +import torch +import torch.distributed as dist +from torch._C._distributed_rpc import _get_current_rpc_agent + + +__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] + +# Thread local tensor tables to store tensors while pickling torch.Tensor +# objects +_thread_local_tensor_tables = threading.local() +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +class RPCExecMode(Enum): + SYNC = "sync" + ASYNC = "async" + ASYNC_JIT = "async_jit" + REMOTE = "remote" + + +class _InternalRPCPickler: + r""" + This class provides serialize() and deserialize() interfaces to serialize + data to be "binary string + tensor table" format + So for RPC python UDF function and args, non tensor data will be serialized + into regular binary string, tensor data will be put into thread local tensor + tables, this serialization format is consistent with builtin operator and args + using JIT pickler. This format will make tensor handling in C++ much easier, + e.g. attach tensor to distributed autograd graph in C++ + """ + + def __init__(self): + # Ignore type error because dispatch_table is defined in third-party package + self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] + self._dispatch_table[torch.Tensor] = self._tensor_reducer + # Used for registering customized picklers. + self._class_reducer_dict = {} + + def _register_reducer(self, obj_class, reducer): + # For the same class, only register the reducer once. + if obj_class not in self._class_reducer_dict: + self._class_reducer_dict[obj_class] = reducer + + @classmethod + def _tensor_receiver(cls, tensor_index): + global _thread_local_tensor_tables + return _thread_local_tensor_tables.recv_tables[tensor_index] + + def _tensor_reducer(self, tensor): + global _thread_local_tensor_tables + _thread_local_tensor_tables.send_tables.append(tensor) + tensor_index = len(_thread_local_tensor_tables.send_tables) - 1 + return (_InternalRPCPickler._tensor_receiver, (tensor_index,)) + + @classmethod + def _py_rref_receiver(cls, rref_fork_data): + return dist.rpc.PyRRef._deserialize(rref_fork_data) + + def _py_rref_reducer(self, py_rref): + rref_fork_data = py_rref._serialize() + return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,)) + + def _rref_reducer(self, rref): + return self._py_rref_reducer(rref) + + @classmethod + def _script_module_receiver(cls, script_module_serialized): + """ + Given a serialized representation of a ScriptModule created with torch.jit.save, + loads and returns the ScriptModule. + """ + f = io.BytesIO(script_module_serialized) + m = torch.jit.load(f) + return m + + def _script_module_reducer(self, script_module): + """ + Serializes a ScriptModule. + """ + f = io.BytesIO() + torch.jit.save(script_module, f) + return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),)) + + def serialize(self, obj): + r""" + Serialize non tensor data into binary string, tensor data into + tensor table + """ + f = io.BytesIO() + p = _pickler(f) + p.dispatch_table = self._dispatch_table + + # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref, + # user picklers could have different initialization function from _InternalRPCPickler, + # but all the user picklers should call serialize() and use _rref_reducer to pickle rref + # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not + # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor, + # so putting rref's dispatch table here + # + # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`. + # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index] + # An RRef created locally by RRef Python constructor is type of `rpc.RRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index] + + # Add dispatch pickling for ScriptModule or its subclass. + if isinstance(obj, torch.jit.ScriptModule): + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] + + # Install customized picklers. + for class_name in self._class_reducer_dict: + p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index] + + # save _thread_local_tensor_tables.send_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "send_tables"): + old_send_tables = _thread_local_tensor_tables.send_tables + else: + old_send_tables = None + _thread_local_tensor_tables.send_tables = [] + + p.dump(obj) + + # restore _thread_local_tensor_tables.send_tables if return + # from nested call, otherwise clean up the table + tensors = _thread_local_tensor_tables.send_tables + if old_send_tables is not None: + _thread_local_tensor_tables.send_tables = old_send_tables + else: + del _thread_local_tensor_tables.send_tables + + return (f.getvalue(), tensors) + + def deserialize(self, binary_data, tensor_table): + r""" + Deserialize binary string + tensor table to original obj + """ + # save _thread_local_tensor_tables.recv_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "recv_tables"): + old_recv_tables = _thread_local_tensor_tables.recv_tables + else: + old_recv_tables = None + _thread_local_tensor_tables.recv_tables = tensor_table + + try: + unpickler = _unpickler(io.BytesIO(binary_data)) + ret = unpickler.load() + except AttributeError as e: + # Occurs when function is not found on module/class during + # unpickling. + except_str = ( + str(e) + + """ Default RPC pickler does not serialize + function code. Ensure that UDFs are defined on both caller and + callee modules.""" + ) + ret = AttributeError(except_str) + # Ensure the stack trace gets preserved + ret.__cause__ = e + + # restore _thread_local_tensor_tables.recv_tables if return + # from nested call, otherwise clean up the table + if old_recv_tables is not None: + _thread_local_tensor_tables.recv_tables = old_recv_tables + else: + del _thread_local_tensor_tables.recv_tables + + return ret + + +# Create _internal_rpc_pickler only once to initialize _dispatch_table only once +_internal_rpc_pickler = _InternalRPCPickler() + + +def serialize(obj): + return _internal_rpc_pickler.serialize(obj) + + +def deserialize(binary_data, tensor_table): + return _internal_rpc_pickler.deserialize(binary_data, tensor_table) + + +def _run_function(python_udf): + r""" + This function is exclusively called from C++. + See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``. + + Runs a Python UDF and returns its return value. + Wraps any exception in ``RemoteException`` if the function raises. + """ + try: + if isinstance(python_udf, AttributeError): + raise python_udf + result = python_udf.func(*python_udf.args, **python_udf.kwargs) + except Exception as e: + # except str = exception info + traceback string + except_str = ( + f"On {_get_current_rpc_agent().get_worker_info()}:\n" + f"{repr(e)}\n{traceback.format_exc()}" + ) + print(except_str, file=sys.stderr) + result = RemoteException(except_str, type(e)) + return result + + +def _handle_exception(result): + if isinstance(result, RemoteException): + exception_msg = result.msg.encode("utf-8").decode("unicode_escape") + # We wrap exception re-creation here in case some exception classes + # cannot be constructed directly from a string. + exc = None + try: + exc = result.exception_type(exception_msg) + except BaseException as e: # noqa: B036 + raise RuntimeError( # noqa: B904 + f"Failed to create original exception type. Error msg was {str(e)}" + f" Original exception on remote side was {exception_msg}" + ) from e + + if exc is not None: + raise exc + + +def _build_rpc_profiling_key( + exec_type, func_name, current_worker_name, dst_worker_name +): + """ + Builds the key that RPC calls are profiled with using the autograd profiler. + This will be the name of the corresponding Event recorded in the profiler. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dst_worker_name (str): Name of the destination worker. + + Returns: + String representing profiling key + """ + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) + return profile_key + + +def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name): + """ + This function should be called from RPC/RRef functions to create a + RecordFunction object for profiling. This function also runs the before + callbacks that start the profiling, though the user is responsible for + running the appropriate callbacks when the function to be profiled finishes. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dest_worker_name (str): Name of the destination worker. + + Returns: + An instance of `torch.autograd._RecordFunction`. + """ + assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled." + profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})" + rf = torch.autograd._RecordFunction() # type: ignore[attr-defined] + torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined] + return rf + + +PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"]) +RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/options.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/options.py new file mode 100644 index 0000000000000000000000000000000000000000..c58a2bf923910039502ed98f1fd742b827800f20 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/options.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from typing import Union + +import torch + +from . import _is_tensorpipe_available, constants as rpc_contants + + +DeviceType = Union[int, str, torch.device] + +__all__ = ["TensorPipeRpcBackendOptions"] + + +def _to_device(device: DeviceType) -> torch.device: + device = torch.device(device) + if device.type != "cuda": + raise ValueError( + "`set_devices` expect a list of CUDA devices, but got " + f"device type {device.type}." + ) + return device + + +def _to_device_map( + device_map: dict[DeviceType, DeviceType], +) -> dict[torch.device, torch.device]: + full_device_map: dict[torch.device, torch.device] = {} + reverse_map: dict[torch.device, torch.device] = {} + for k, v in device_map.items(): + k, v = torch.device(k), torch.device(v) + if v in reverse_map: + raise ValueError( + "`device_map` only supports 1-to-1 mapping, " + f"trying to map {k} and {reverse_map[v]} to {v}" + ) + full_device_map[k] = v + reverse_map[v] = k + return full_device_map + + +def _to_device_list(devices: list[DeviceType]) -> list[torch.device]: + return list(map(_to_device, devices)) + + +if _is_tensorpipe_available: # type: ignore[has-type] + from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase +else: + _TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc] + + +# pyrefly: ignore [invalid-inheritance] +class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + r""" + The backend options for + :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from + :class:`~torch.distributed.rpc.RpcBackendOptions`. + + Args: + num_worker_threads (int, optional): The number of threads in the + thread-pool used by + :class:`~torch.distributed.rpc.TensorPipeAgent` to execute + requests (default: 16). + rpc_timeout (float, optional): The default timeout, in seconds, + for RPC requests (default: 60 seconds). If the RPC has not + completed in this timeframe, an exception indicating so will + be raised. Callers can override this timeout for individual + RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and + :meth:`~torch.distributed.rpc.rpc_async` if necessary. + init_method (str, optional): The URL to initialize the distributed + store used for rendezvous. It takes any value accepted for the + same argument of :meth:`~torch.distributed.init_process_group` + (default: ``env://``). + device_maps (Dict[str, Dict], optional): Device placement mappings from + this worker to the callee. Key is the callee worker name and value + the dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``) + that maps this worker's devices to the callee worker's devices. + (default: ``None``) + devices (List[int, str, or ``torch.device``], optional): all local + CUDA devices used by RPC agent. By Default, it will be initialized + to all local devices from its own ``device_maps`` and corresponding + devices from its peers' ``device_maps``. When processing CUDA RPC + requests, the agent will properly synchronize CUDA streams for + all devices in this ``List``. + """ + + def __init__( + self, + *, + num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, + rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = rpc_contants.DEFAULT_INIT_METHOD, + device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None, + devices: list[DeviceType] | None = None, + _transports: list | None = None, + _channels: list | None = None, + ): + full_device_maps = ( + {} + if device_maps is None + else {k: _to_device_map(v) for k, v in device_maps.items()} + ) + full_device_list = [] if devices is None else _to_device_list(devices) + super().__init__( + num_worker_threads, + _transports, + _channels, + rpc_timeout, + init_method, + full_device_maps, + full_device_list, + ) + + def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]): + r""" + Set device mapping between each RPC caller and callee pair. This + function can be called multiple times to incrementally add + device placement configurations. + + Args: + to (str): Callee name. + device_map (Dict of int, str, or torch.device): Device placement + mappings from this worker to the callee. This map must be + invertible. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> # both workers + >>> def add(x, y): + >>> print(x) # tensor([1., 1.], device='cuda:1') + >>> return x + y, (x + y).to(2) + >>> + >>> # on worker 0 + >>> options = TensorPipeRpcBackendOptions( + >>> num_worker_threads=8, + >>> device_maps={"worker1": {0: 1}} + >>> # maps worker0's cuda:0 to worker1's cuda:1 + >>> ) + >>> options.set_device_map("worker1", {1: 2}) + >>> # maps worker0's cuda:1 to worker1's cuda:2 + >>> + >>> rpc.init_rpc( + >>> "worker0", + >>> rank=0, + >>> world_size=2, + >>> backend=rpc.BackendType.TENSORPIPE, + >>> rpc_backend_options=options + >>> ) + >>> + >>> x = torch.ones(2) + >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1)) + >>> # The first argument will be moved to cuda:1 on worker1. When + >>> # sending the return value back, it will follow the invert of + >>> # the device map, and hence will be moved back to cuda:0 and + >>> # cuda:1 on worker0 + >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') + >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') + """ + full_device_map = _to_device_map(device_map) + curr_device_maps = super().device_maps + + if to in curr_device_maps: + for k, v in full_device_map.items(): + if k in curr_device_maps[to] and v != curr_device_maps[to][k]: + raise ValueError( + "`set_device_map` only supports 1-to-1 mapping, trying" + f" to map {k} to {v} and {curr_device_maps[to][k]}" + ) + + super()._set_device_map(to, full_device_map) + + def set_devices(self, devices: list[DeviceType]): + r""" + Set local devices used by the TensorPipe RPC agent. When processing + CUDA RPC requests, the TensorPipe RPC agent will properly synchronize + CUDA streams for all devices in this ``List``. + + Args: + devices (List of int, str, or torch.device): local devices used by + the TensorPipe RPC agent. + """ + self.devices = _to_device_list(devices) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..46eecf19e22c9bcb11a475963f9be0461261b0a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/rref_proxy.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from functools import partial + +import torch +from torch.futures import Future + +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + +def _local_invoke(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +@functions.async_execution +def _local_invoke_async_execution(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): + def _rref_type_cont(rref_fut): + rref_type = rref_fut.value() + + _invoke_func = _local_invoke + # Bypass ScriptModules when checking for async function attribute. + bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass( + rref_type, torch._C.ScriptModule + ) + if not bypass_type: + func = getattr(rref_type, func_name) + if hasattr(func, "_wrapped_async_rpc_function"): + _invoke_func = _local_invoke_async_execution + + return rpc_api( + rref.owner(), + _invoke_func, + args=(rref, func_name, args, kwargs), + timeout=timeout, + ) + + rref_fut = rref._get_type(timeout=timeout, blocking=False) + + if rpc_api is not rpc_async: + rref_fut.wait() + return _rref_type_cont(rref_fut) + else: + # A little explanation on this. + # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]` + # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]` + # To address that, we return a Future that is completed with the result of the async call. + result: Future = Future() + + def _wrap_rref_type_cont(fut): + try: + _rref_type_cont(fut).then(_complete_op) + except BaseException as ex: # noqa: B036 + result.set_exception(ex) + + def _complete_op(fut): + try: + result.set_result(fut.value()) + except BaseException as ex: # noqa: B036 + result.set_exception(ex) + + rref_fut.then(_wrap_rref_type_cont) + return result + + +# This class manages proxied RPC API calls for RRefs. It is entirely used from +# C++ (see python_rpc_handler.cpp). +class RRefProxy: + def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): + self.rref = rref + self.rpc_api = rpc_api + self.rpc_timeout = timeout + + def __getattr__(self, func_name): + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..29a916772d330b555673645a3e38308788b31535 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rpc/server_process_global_profiler.py @@ -0,0 +1,190 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + +import itertools + +import torch + +# pyrefly: ignore [deprecated] +from torch.autograd.profiler_legacy import profile + +from . import ( + _disable_server_process_global_profiler, + _enable_server_process_global_profiler, +) + + +__all__: list[str] = [] + + +class _server_process_global_profile(profile): + """ + It has the same API as ``torch.autograd.profiler.profile`` class, + except that it enables profiling on all threads running RPC server request callbacks. + + Context manager that manages autograd profiler state and holds a summary of results. + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + Default: ``True``. + + use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + Default: ``False`` + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + profile_memory (bool, optional): Whether to report memory usage, default: ``False`` + + .. warning:: + Enabling memory profiling incurs additional profiler overhead + + .. warning:: + Due to some CUDA multiprocessing limitations (see :ref:`multiprocessing-cuda-note`), + one cannot use the profiler with ``use_cuda = True`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_cuda = False`` or ``num_workers = 0``. + + Example: + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> x, y = torch.tensor(1), torch.tensor(2) + >>> outer_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> outer_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) + >>> inner_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> inner_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) + >>> inner_profile_rref.rpc_sync().__exit__(None, None, None) + >>> outer_profile_rref.rpc_sync().__exit__(None, None, None) + >>> print(inner_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 85.06% 76.275us 100.00% 89.667us 89.667us 1 + empty 14.94% 13.392us 14.94% 13.392us 13.392us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 89.667us + >>> print(outer_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 35.65% 76.275us 41.91% 89.667us 89.667us 1 + empty 12.67% 27.101us 12.67% 27.101us 13.551us 2 + add 51.68% 110.550us 58.09% 124.259us 124.259us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 213.926us + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __enter__(self): + """ + Turn on server-side process-global profiling. + This enables thread-local profiler on all RPC threads running server-side request callbacks. + """ + if not self.enabled: + return + + if self.entered: # type: ignore[has-type] + raise RuntimeError("autograd profiler traces are not reentrant") + self.entered = True + + profiler_kind = ( + torch.autograd.ProfilerState.CUDA + if self.use_cuda + else torch.autograd.ProfilerState.CPU + ) + profiler_config = torch.autograd.ProfilerConfig( + profiler_kind, + self.record_shapes, + self.profile_memory, + False, + False, + False, + torch.profiler._ExperimentalConfig(), + ) + _enable_server_process_global_profiler(profiler_config) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Turn off server-side process-global profiling. + Aggregate all profiling events recorded by RPC threads. + + These attributes are assigned on exiting context. + + Attributes: + function_events (torch.autograd.profiler.EventList). It's a list that has helper + methods, like 1) show record items in a pretty-print table. + 2) do averaging by grouping on keys. 3) and more. + + process_global_function_events (List[torch.autograd.profiler.FunctionEvent]). + It's a list of ``FunctionEvent`` elements. Every element is a profiling result + of an RPC request handling within the profiling range. + """ + if not self.enabled: + return + + process_global_events = _disable_server_process_global_profiler() + + # Every element in this list is a thread profiling result from an RPC request handling. + process_global_function_events = [] + for thread_local_events in process_global_events: + # Parse from ``Event``s to ``FunctionEvent``s. + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) + ) + thread_local_function_events.sort( + key=lambda function_event: [ + function_event.time_range.start, + -(function_event.time_range.end), + ] + ) + process_global_function_events.append(thread_local_function_events) + + flattened_function_events = list( + itertools.chain.from_iterable(process_global_function_events) + ) + # pyrefly: ignore [bad-assignment] + self.function_events = torch.autograd.profiler_util.EventList( + flattened_function_events, + use_device="cuda" if self.use_cuda else None, + profile_memory=self.profile_memory, + ) + # pyrefly: ignore [missing-attribute] + self.function_events._build_tree() + + self.process_global_function_events = process_global_function_events + + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..067d4c0917e9de33b516c7ed47c678be2ac6c692 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__init__.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +# For weights_only torch.load +from ._dtensor_spec import ( + DTensorSpec as _DTensorSpec, + ShardOrderEntry as _ShardOrderEntry, + TensorMeta as _TensorMeta, +) + + +torch.serialization.add_safe_globals( + [ + DeviceMesh, + _DTensorSpec, + _TensorMeta, + _ShardOrderEntry, + DTensor, + Partial, + Replicate, + Shard, + ] +) + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) # type: ignore[arg-type] + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f32433b645130868a535f58c32e62e118e67236 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b597e41d64870542a48e03ca77b86222299314d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_argmin_argmax.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_argmin_argmax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e9b9b851a4b625107da55ba254efb4a80f8378 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_argmin_argmax.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d48d2bcf2ad5c3d22a33e8c5bbb854b661ac1873 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e86358101d8100506c192d046311b7603f4bacba Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a47b80287799b7cded4a1e1fb05823b6490c6b12 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a843b6697f3e26dc1033d90f318ac5ee5e4e5273 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c66f5a531e463c11603e9145e3f34a36c48bcb74 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fede90660e7458a254544c81f88c70235433201 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c6d8762d8a8c1663cdfa6b806a4b28cd70de407 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5073606399c33415e0c6edb5f0049bf7c4977b74 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..776ececd9d1833fbc4f885e4ab4bb25be323a936 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efaed71237bcb289bf2e2f20fca4e61b5a61603 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d89a1df7b6b0d91292b623bc2291e11322ec5c5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27d0b94236e328577228ef8804888d7bf0b501e8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..78e00d5137ea075fdcda11d3e97f2ed7ed7f3f0a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py @@ -0,0 +1,1385 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import inspect +import warnings +from collections.abc import Callable, Sequence +from typing import Any +from typing_extensions import deprecated + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape_and_global_offset, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Sequence[Placement] | None, + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + if ( + tensor_stride == dtensor_meta.stride + and grad_placements == dtensor_spec.placements + ): + # Avoid actual sharing of specs in case they're modified during (e.g.) + # sharding propagation. + grad_spec = copy.copy(dtensor_spec) + else: + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + return ( + # pyrefly: ignore [bad-argument-type] + DTensor( + # pyrefly: ignore [bad-argument-count] + grad_output, + grad_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + run_check: bool, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + # pyrefly: ignore [bad-argument-type] + dist_tensor = DTensor( + # pyrefly: ignore [bad-argument-count] + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + # pyrefly: ignore [unexpected-keyword] + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + .. note:: Directly using the Tensor subclass constructor here is not the recommended way to create a ``DTensor`` + (i.e. it does not handle autograd correctly hence is not the public API). Please refer to the `create_dtensor`_ + section to see how to create a ``DTensor``. + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + # This implementation is just to convince mypy _spec and _local_tensor are + # initialized; it is immediately overridden below. + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + r = torch.Tensor._dtensor__new__( + cls, local_tensor, spec, requires_grad=requires_grad + ) + r._spec = spec + r._local_tensor = local_tensor + return r + + __new__ = torch.Tensor._dtensor__new__ # type: ignore[assignment] # noqa: F811 + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental + def __init__(self, *args, **kwargs): + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + super().__init__() + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): # type: ignore[override] + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) + local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor, + unflatten_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): + if expected_type is not None: + return None + + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + # We just need to have an implementation here; the __torch_dispatch__ machinery + # calls into a specific C++ fast path that doesn't call here. + # See #167051 for details + # python_arg_parser.cpp: dispatch_on_subclass() + # -> python_variable.cpp: dispatchDTensorOp() + raise NotImplementedError( + "DTensor.__torch_dispatch__ should not actually get called" + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + run_check: bool = False, + shape: torch.Size | None = None, + stride: tuple[int, ...] | None = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # `local_tensor` argument cannot be DTensor + if isinstance(local_tensor, DTensor): + raise RuntimeError( + f"the local_tensor argument only accepts torch.Tensor but got {type(local_tensor)} value." + ) + + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if isinstance(placement, Shard | _StridedShard): + if placement.dim < 0: + normalized_dim = placement.dim + local_tensor.ndim + if type(placement) is _StridedShard: + placements[idx] = _StridedShard( + normalized_dim, split_factor=placement.split_factor + ) + elif type(placement) is Shard: + placements[idx] = Shard(normalized_dim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Sequence[Placement] | None = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + async_op: bool = False, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from its current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + When redistributing from current to the new placements on one device mesh dimension, we + will perform the following operations including communication collective or local operation: + + 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` + 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` + 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) + 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` + 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` + + + ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors + that are created either on 1-D or N-D DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. If not specified, it would use the current DTensor's DeviceMesh. + default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + default: replicate on all mesh dimensions + + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + forward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``forward_dtype`` before redistributing the local tensor in its forward. + The result DTensor will be in ``forward_dtype`` Default: None. + backward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``backward_dtype`` before redistributing the local tensor in its backward. + The result DTensor gradient would be converted back to the current DTensor dtype. Default: None + + Returns: + A :class:`DTensor` object + + .. note:: ``redistribute`` is differentiable, which means user do not need to worry about + the backward formula of the redistribute operation. + + .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, + Please file an issue if you need to redistribute DTensor to different DeviceMesh. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial() and self.placements[i] != placement: + raise RuntimeError( + f"Can not redistribute from {self.placements[i]} to {placement}, " + "redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + elif isinstance(placement, _StridedShard) and placement.dim < 0: + placements[i] = _StridedShard( + placement.dim + self.ndim, split_factor=placement.split_factor + ) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply( + self, device_mesh, placements, async_op, forward_dtype, backward_dtype + ) + + def full_tensor( + self, *, grad_placements: Sequence[Placement] | None = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntactic sugar of the following code: + + ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: ``full_tensor`` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: ``device_mesh`` is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> tuple[Placement, ...]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: ``placements`` is a read-only property, it can not be set. + """ + return self._spec.placements + + def _raise_if_contains_partial_placements(self) -> None: + """ + Raise an error if the DTensor contains partial placements. + """ + for placement in self._spec.placements: + if not isinstance(placement, Partial): + continue + + raise ValueError( + "Any checkpointing related operations are not supported for " + "DTensor with partial placements!" + ) + + def __create_write_items__(self, fqn: str, object: Any): + self._raise_if_contains_partial_placements() + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + """ + Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica + on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only + has one element. + + This dunder method is primariy used for distributed checkpoint purpose. + + Returns: + A List[:class:`ChunkStorageMetadata`] object that represents the shard size/offset on the current rank. + """ + self._raise_if_contains_partial_placements() + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + self._raise_if_contains_partial_placements() + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + @classmethod + def __metadata_guard__( + cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool] + ) -> bool: + # TODO - delete this - This is now unused after the PR - + # https://github.com/pytorch/pytorch/pull/165824 + orig_spec, orig_requires_grad = orig + other_spec, other_requires_grad = other + return ( + orig_spec._check_equals(other_spec, skip_shapes=True) + and orig_requires_grad == other_requires_grad + ) + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + *, + src_data_rank: int | None = 0, +) -> DTensor: + """ + Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the + same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use + the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to preserve + the single-device semantic. If you want to construct a DTensor in the middle of the Autograd + computation, please use :meth:`DTensor.from_local` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use ``torch.chunk`` + semantic to shard the tensor and scatter the shards. The uneven sharding + behavior is experimental and subject to change. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as ``device_mesh.ndim``. If not specified, we will + by default replicate the tensor across the ``device_mesh`` from the + first rank of each dimension of the `device_mesh`. + + Keyword args: + src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is + used by :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. + By default, we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve + the single-device semantic. If passing ``None`` explicitly, :meth:`distribute_tensor` simply uses + its local data instead of trying to preserve the single-device semantic via scatter/broadcast. + Default: 0 + + Returns: + A :class:`DTensor` or ``XLAShardedTensor`` object. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` + return `XLAShardedTensor` instead. see `this issue `__ + for more details. The XLA integration is experimental and subject to change. + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") + + # get default device mesh if there's nothing specified + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # call PyTorch/XLA SPMD for `xla` backend type device mesh. + # This returns XLAShardedTensor + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_tensor, + ) + + return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + if not tensor.is_leaf: + raise RuntimeError( + "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" + ) + + # convert tensor to the corresponding device type if it's not in that device type + if device_type != tensor.device.type and not tensor.is_meta: + tensor = tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we need to check: + # 1. if the we can further shard this DTensor if the two device mesh belong to + # the same parenet mesh and further sharding is possible. + # 2. check if device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != tuple(placements): + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor.detach() + + # TODO(xilun): address sharding order + # distribute the tensor according to the placements. + placements = list(placements) + for idx, placement in enumerate(placements): + if isinstance(placement, Shard | _StridedShard): + placement_dim = ( + placement.dim + tensor.ndim if placement.dim < 0 else placement.dim + ) + if isinstance(placement, Shard): + local_tensor = Shard._make_shard_tensor( + placement_dim, local_tensor, device_mesh, idx, src_data_rank + ) + placements[idx] = Shard(placement_dim) + else: + local_tensor = _StridedShard._make_shard_tensor( + placement_dim, + local_tensor, + device_mesh, + idx, + src_data_rank, + split_factor=placement.split_factor, + ) + placements[idx] = _StridedShard( + placement_dim, split_factor=placement.split_factor + ) + elif isinstance(placement, Replicate): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + elif isinstance(placement, Partial): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + local_tensor = placement._partition_value(local_tensor, device_mesh, idx) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + placements = tuple(placements) + + assert local_tensor is not None, "distributing a tensor should not be None" + # detach the local tensor passed to DTensor since after the construction + # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor.requires_grad_(tensor.requires_grad), + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=tensor.requires_grad, + ) + + +@deprecated("Please use `distribute_tensor` with `src_data_rank=None` instead.") +def _shard_tensor( + full_tensor: torch.Tensor, + placements: Sequence[Shard], + device_mesh: DeviceMesh | None = None, +) -> "DTensor": + """ + Locally shards a full tensor based on indicated sharding arrangement, and + returns a DTensor containing the local shard. + + .. warning:: This is a private API that is subject to change. It skips the + communication otherwise required by `distribute_tensor`. It is only + applicable to cases where all ranks have the same `full_tensor`. For + example, in distributed inference all ranks load from the same + checkpoint. This API will not check for data equality between ranks, it + is thus user's responsibility to ensure the `full_tensor` is the same + across ranks. + + Args: + full_tensor (torch.Tensor): the full tensor to be sharded. + placements (Sequence[:class:`Shard`]): the placements that + describes how to place the local tensor on DeviceMesh. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. Must have same dimension as the number of placements. + If not specified, would be retrieve from current context. + + Returns: + A :class:`DTensor` object with the shard as its local tensor. + + Examples: + >>> # xdoctest: +SKIP("need world_size and rank") + >>> device_mesh = dist.init_device_mesh("cuda", (world_size,)) + >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}") + >>> dtensor = _shard_tensor(full_tensor, [Shard(1)], device_mesh) + """ + return distribute_tensor(full_tensor, device_mesh, placements, src_data_rank=None) + + +def distribute_module( + module: nn.Module, + device_mesh: DeviceMesh | None = None, + partition_fn: Callable[[str, nn.Module, DeviceMesh], None] | None = None, + input_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, + output_fn: Callable[[nn.Module, Any, DeviceMesh], None] | None = None, +) -> nn.Module: + """ + This function expose three functions to control the parameters/inputs/outputs of the module: + + 1. To perform sharding on the module before runtime execution by specifying the + ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` + parameters according to the `partition_fn` specified). + 2. To control the inputs or outputs of the module during runtime execution by + specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to + :class:`DTensor`, convert the output back to ``torch.Tensor``) + + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the ``device_mesh``). If ``partition_fn`` is not specified, + by default we replicate all module parameters of ``module`` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. ``input_fn`` will be installed as a module + ``forward_pre_hook`` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be + installed as a module ``forward_hook`` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all ``DTensor`` s. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` + return nn.Module with PyTorch/XLA SPMD annotated parameters. See + `this issue `__ + for more details. The XLA integration is experimental and subject to change. + + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + already_distributed = getattr(module, "_distribute_module_applied", False) + if already_distributed: + raise RuntimeError( + "distribute_module should only be called once on a module, " + "but it has already been called on this module!" + ) + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for submod in module.modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook( + lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" + ) + # register output_fn as module forward hook + if output_fn is not None: + num_args = len(inspect.signature(output_fn).parameters) + if num_args == 2: + # output_fn only takes in outputs and device mesh + warnings.warn( + "Deprecating output_fn that takes two arguments (inputs, device_mesh), " + "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) + ) + else: + raise ValueError( + f"output_fn should take in 3 arguments, but got {num_args} arguments!" + ) + + module._distribute_module_applied = True # type: ignore[assignment] + return module + + +# Below are tensor factory function APIs, which are used to create a DTensor directly. We need +# to make separate factory function APIs because tensor subclass could not override the tensor +# factory methods, and we need user to call the factory functions with user intended device_mesh +# and placements to create a proper DTensor. + + +def _dtensor_init_helper( # type: ignore[no-untyped-def] + init_op, + size: torch.Size, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, + **kwargs, +) -> DTensor: + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh against placements + assert device_mesh.ndim == len(placements), ( + "mesh dimension does not match the length of placements" + ) + + assert kwargs["layout"] == torch.strided, "layout value not supported!" + torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape, _ = compute_local_shape_and_global_offset( + size, device_mesh, placements, skip_offset=True + ) + + # initialize the local tensor + if init_op is torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op is torch.rand or init_op is torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh) + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + local_tensor, + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=kwargs["requires_grad"], + ) + + +def ones( # type: ignore[no-untyped-def] + *size, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: torch.dtype | None = None, + layout: torch.layout = torch.strided, + device_mesh: DeviceMesh | None = None, + placements: Sequence[Placement] | None = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py new file mode 100644 index 0000000000000000000000000000000000000000..730291a7926b3130e23a0d1b98b6d6170fca03c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_argmin_argmax.py @@ -0,0 +1,120 @@ +import operator +from functools import reduce + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + + +_REDUCTION_OPS = { + torch.ops.aten.argmax.default: torch.max, + torch.ops.aten.argmin.default: torch.min, +} + + +def argmin_argmax_handler( + op_call: torch._ops.OpOverload, + args: tuple["dtensor.DTensor", int] | tuple["dtensor.DTensor", int, bool], + kwargs: dict[str, object], +): + """ + Handles reduces on sharded dimensions locally to limit calls to replicate. + """ + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + if op_call not in _REDUCTION_OPS: + raise NotImplementedError(f"Unsupported reduction op: {op_call}") + val_op = _REDUCTION_OPS[op_call] + + input_dtensor = args[0] + if not isinstance(input_dtensor, dtensor.DTensor): + raise NotImplementedError + + dim: int | None = args[1] if len(args) > 1 else None # type: ignore[assignment] + keepdim = args[2] if len(args) > 2 else False + + placements = input_dtensor.placements + + # check for partial placements and handle it as replicate. + if any(isinstance(p, Partial) for p in placements): + target_placements = [ + Replicate() if isinstance(p, Partial) else p for p in placements + ] + input_dtensor = input_dtensor.redistribute( + device_mesh=input_dtensor.device_mesh, placements=target_placements + ) + placements = input_dtensor.placements + local_tensor = input_dtensor.to_local() + + input_shape = list(local_tensor.shape) + if dim is None: + expected_shape = ( + torch.Size([1] * len(input_shape)) if keepdim else torch.Size([]) + ) + elif keepdim: + if input_shape: + input_shape[dim] = 1 + expected_shape = torch.Size(input_shape) + else: + if input_shape: + input_shape.pop(dim) + expected_shape = torch.Size(input_shape) + + shard_mesh_dims = [] + for mesh_dim, p in enumerate(placements): + if isinstance(p, Shard): + if dim is None or p.dim == (dim if dim >= 0 else local_tensor.ndim + dim): + shard_mesh_dims.append(mesh_dim) + + device_mesh = input_dtensor.device_mesh + + if dim is None: + local_idx = op_call(local_tensor) + local_max = local_tensor.flatten()[local_idx] + else: + local_max, local_idx = val_op(local_tensor, dim=dim, keepdim=True) + + if not shard_mesh_dims: + return dtensor.DTensor._op_dispatcher.wrap( + local_idx.reshape(expected_shape), output_sharding.output_spec + ) + + # find the correct offset for sharded dim + global_shape = input_dtensor.shape + _, global_offset = compute_local_shape_and_global_offset( + global_shape, device_mesh, placements + ) + gathered_maxes = local_max + if dim is None: + local_coord = torch.unravel_index(local_idx, local_tensor.shape) + global_coord = torch.stack(local_coord) + gather_dim = 0 + for i, offset in enumerate(global_offset): + global_coord[i] += offset + # compute with proper striding + gathered_idxs = torch.tensor(0, device=local_tensor.device, dtype=torch.long) + for i, coord in enumerate(global_coord): + gathered_idxs += coord * reduce(operator.mul, global_shape[i + 1 :], 1) + else: + gather_dim = dim + gathered_idxs = local_idx + global_offset[dim] + + for mesh_dim in shard_mesh_dims: + gathered_maxes = funcol.all_gather_tensor( + gathered_maxes, gather_dim=gather_dim, group=(device_mesh, mesh_dim) + ) + gathered_idxs = funcol.all_gather_tensor( + gathered_idxs, gather_dim=gather_dim, group=(device_mesh, mesh_dim) + ) + + rank_winner = op_call(gathered_maxes, dim, True) + + final_idx = torch.gather(gathered_idxs, dim=gather_dim, index=rank_winner) + + return dtensor.DTensor._op_dispatcher.wrap( + final_idx.reshape(expected_shape), output_sharding.output_spec + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..766b030ad9524d7c3e8dc185ac3ff3056e4bcc77 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_collective_utils.py @@ -0,0 +1,396 @@ +# mypy: allow-untyped-defs +import logging +import math +from dataclasses import dataclass +from functools import lru_cache +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._dtensor_spec as dtensor_spec +from torch._C._distributed_c10d import _resolve_process_group +from torch._logging import warning_once +from torch.distributed._local_tensor import ( + local_tensor_mode, + maybe_run_for_local_tensor, +) +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + broadcast, + get_group_rank, + get_rank, + ProcessGroup, + scatter, + Work, +) + + +logger = logging.getLogger(__name__) + + +@torch.library.register_fake("_dtensor::shard_dim_alltoall") +def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() + ) + + +def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): + if mesh.device_type == "cpu" and local_tensor_mode() is None: + # Gloo does not support alltoall, so falling back to allgather + chunk + warning_once( + logger, + "CPU process group does not support alltoall yet, falling back with allgather + chunk!", + ) + out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim)) + if isinstance(out, funcol.AsyncCollectiveTensor): + # stick to the same behavior for the alltoall case, remove this once we enable alltoall async + out = out.wait() + out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[ + mesh.get_local_rank(mesh_dim) + ] + return out.contiguous() + + group_name = funcol._resolve_group_name((mesh, mesh_dim)) + # TODO: enable async op for shard_dim_alltoall + return torch.ops._dtensor.shard_dim_alltoall( + input, gather_dim, shard_dim, group_name + ) + + +def mesh_scatter( + output: torch.Tensor, + scatter_list: list[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Work | None: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank + 2 to rank 2/3. + + Args: + output (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if output.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + if group_src == get_rank(dim_group): + fut = scatter( + output, + scatter_list=scatter_list, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + else: + fut = scatter( + output, + scatter_list=None, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + + return fut + + +def mesh_broadcast( + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Work | None: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if tensor.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + return broadcast(tensor, group=dim_group, async_op=async_op, group_src=group_src) + + +@maybe_run_for_local_tensor +def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + pad = [0, 0] * (tensor.ndim - pad_dim) + pad[-1] = pad_size + return torch.nn.functional.pad(tensor, pad) + + +@maybe_run_for_local_tensor +def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + return tensor.narrow( + pad_dim, + start=0, + length=tensor.size(pad_dim) - pad_size, + ) + + +def fill_empty_tensor_to_shards( + shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int +) -> list[torch.Tensor]: + if num_empty_tensors == 0: + return shards + tensor_size = list(shards[0].size()) + tensor_size[shard_dim] = 0 + tensor = shards[0].new_zeros(tensor_size) + shards.extend(tensor for _ in range(num_empty_tensors)) + return shards + + +def check_tensor_meta( + local_tensor, check_shape_stride=False +) -> Optional["dtensor_spec.TensorMeta"]: + local_metadata = { + "dtype": local_tensor.dtype, + "requires_grad": local_tensor.requires_grad, + } + + if check_shape_stride: + local_metadata.update( + {"shape": local_tensor.shape, "stride": local_tensor.stride()} + ) + + gathered_metadata = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_metadata, local_metadata) + + # Check if metadata is consistent across ranks + if not all(meta == local_metadata for meta in gathered_metadata): + raise ValueError( + "Inconsistent tensor metadata (including shape and stride) across ranks." + ) + return None + + +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: + assert spec.tensor_meta is not None, "spec should have tensor meta defined!" + return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) + + +@dataclass +class MeshTopoInfo: + """ + Mesh information for collective cost estimation + """ + + mesh: DeviceMesh + mesh_dim_devices: list[int] + mesh_dim_bandwidth: list[float] + mesh_dim_latency: list[float] + + @staticmethod + @lru_cache(None) + def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo": + # Generate mesh topology info for intra-host/inter-host communication pattern + # Note that we made bunch of assumptions for simplicity: + # 1. we assume the mesh is homogeneous, and it's gpu/nccl model + # 2. we assume gpu arch is Ampere or Hopper + # 3. we assume collectives are all ring base algo for now + num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type) + # the base bw number (intra-node), GB/s + base_bw = 87.7 + mesh_dim_bandwidth = [base_bw] * mesh.ndim + # the latency in terms of us (intra-node, nv-link) + mesh_dim_latency = [0.6] * mesh.ndim + mesh_dim_devices = [1] * mesh.ndim + + total_num_devices = 1 + for mesh_dim in reversed(range(mesh.ndim)): + num_devices = mesh.size(mesh_dim) + mesh_dim_devices[mesh_dim] = num_devices + total_num_devices *= num_devices + if total_num_devices > num_devices_per_host: + # magic number for inter-host communication bandwidth/latency factor + # This number assumes latest GPU arch, i.e. Ampere or Hopper + # TODO: see if we need to tweak this or offer a way for user + # to specify the bandwidths/latency + mesh_dim_bandwidth[mesh_dim] *= 0.22 + # set to ethernet latency for inter-host + mesh_dim_latency[mesh_dim] = 2.7 + + return MeshTopoInfo( + mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency + ) + + +def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s + return latency + bw * 1e6 # rescale to us + + +def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter + num_hops = 2 * (num_devices_on_mesh_dim - 1) + + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def reduce_scatter_cost( + bytes_gb: float, + mesh_topo: MeshTopoInfo, + mesh_dim: int, +) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def redistribute_cost( + current_spec: "dtensor_spec.DTensorSpec", + target_spec: "dtensor_spec.DTensorSpec", +) -> float: + """ + This function returns the cost of redistribute from current to target DTensorSpec. + + NOTE: + 1. Only consider communication cost here, since computation costs for redistribute + are quite trivial (i.e. we only need to narrow or simple division) + 2. Only consider redistribute cost on same mesh, cross mesh communication cost is + not quite needed for operator strategy estimation/selection. + """ + if current_spec.mesh != target_spec.mesh: + # make infinite cost if meshes are not same + # TODO: see if we want to support this once there's cross mesh communication + return float("inf") + + if current_spec.is_replicated(): + # short-cut: + # comm cost is 0 if current spec is already full replication + return 0.0 + + mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) + cost = 0.0 + comm_bytes_gb = ( + spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 + ) + # Transformation that considered for redistribute cost: + # 1. allgather 2. alltoall + # 3. allreduce 4. reduce_scatter + from torch.distributed._functional_collectives import _are_we_tracing + from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + _gen_transform_infos_non_cached, + ) + + # No redistribution needed when placements are already identical. + # This also prevents potential failures in _gen_transform_infos for certain configurations + # (e.g., sub-meshes) where finding a transform path between identical states may error out. + # TODO(zpcore): test placements with _StridedShard. + if current_spec.placements == target_spec.placements: + return cost + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) + for transform_info in transform_infos: + assert current_spec.tensor_meta is not None, ( + "spec should have tensor meta defined!" + ) + current = transform_info.src_dst_placements[0] + target = transform_info.src_dst_placements[1] + if current == target: + continue + mesh_dim = transform_info.mesh_dim + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes_gb *= num_devices_on_mesh_dim + # add up allgather comm cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + elif current.is_shard() and target.is_shard(): + # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty + # to favor allgather instead + # TODO: add alltoall_cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0 + elif current.is_partial() and target.is_replicate(): + # add up allreduce comm cost + cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) + elif current.is_partial() and target.is_shard(): + # add up reduce_scatter comm cost + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) + # after reduce_scatter the comm bytes for further collectives halved. + comm_bytes_gb /= num_devices_on_mesh_dim + elif current.is_shard() and target.is_partial(): + # ban shard -> partial as it does not make sense to perform + # this redistribute + return float("inf") + + return cost diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..54c0cf63440b947587eca96781371c98ffa58407 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py @@ -0,0 +1,653 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import logging +import warnings +from collections.abc import Sequence +from typing import cast + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch._library.utils import fill_defaults +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._argmin_argmax import argmin_argmax_handler +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OutputSharding, + OutputSpecType, +) +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( + convolution_backward_handler, + convolution_handler, +) +from torch.distributed.tensor._utils import ( + ExplicitRedistributionContext, + try_find_mesh_from_args, +) +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate +from torch.utils._debug_mode import get_active_debug_mode +from torch.utils._python_dispatch import return_and_correct_aliasing + + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +def as_strided_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +): + args, kwargs = fill_defaults(op_call._schema, args, kwargs) + assert not kwargs + tensor, size, stride, storage_offset = args + if ( + tensor.size() == tuple(size) + and tensor.stride() == tuple(stride) + and (storage_offset is None or tensor.storage_offset() == storage_offset) + ): + return torch.ops.aten.alias.default(tensor) + raise RuntimeError("as_strided not supported with DTensor") + + +def is_same_size_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> bool: + lhs = cast(torch.Tensor, args[0]) + rhs = cast(torch.Tensor, args[1]) + return lhs.shape == rhs.shape + + +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(list[object], op_info.local_args), + op_info.args_tree_spec, # type: ignore[arg-type] + ) + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword] + spec=spec, # pyrefly: ignore [unexpected-keyword] + requires_grad=False, # pyrefly: ignore [unexpected-keyword] + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + +class OpDispatcher: + """ + Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding + propagation, redistribute local args, local compute, and post-processing (re-wrapping). It + also handles any op specific logic if necessary. + + NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher + is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster + pytree if needed, and leveraging various caching mechanisms implemented in the sharding + propagation and redistribute modules. The CPU overhead is critical to eager mode performance, + one need to carefully measure the CPU overhead when making significant changes to the + OpDispatcher and ShardingPropagator. + """ + + def __init__(self) -> None: + self.sharding_propagator = ShardingPropagator() + # NOTE: must stay in sync with is_random_op in + # torch/csrc/autograd/python_variable.cpp + self._random_ops = { + aten.native_dropout.default, + aten.normal_.default, + aten.rand.default, + aten.rand_like.default, + aten.randn.default, + aten.randn_like.default, + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.uniform_.default, + aten.bernoulli.default, + aten.bernoulli_.float, + } + self._custom_op_handlers = { + aten.is_same_size.default: is_same_size_handler, + aten.convolution.default: convolution_handler, + aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, + aten.as_strided.default: as_strided_handler, + aten.argmin.default: argmin_argmax_handler, + aten.argmax.default: argmin_argmax_handler, + } + + # ******************************************************************************************** + # def dispatch(...) + # + # NOTE: this class no longer contains the top-level dispatch entrypoint! + # See #167051 for details + # + # The entrypoint has been moved to C++, and it handles common cases and then calls back into + # OpDispatcher python to handle corner cases. + # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp + # ******************************************************************************************** + + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) + # as implicitly replicated or we throw error to user. + # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave + # it as False by default. + @property + def _allow_implicit_replication(self) -> bool: + return torch._C._get_dtensor_allow_implicit_replication() + + @_allow_implicit_replication.setter + def _allow_implicit_replication(self, value: bool) -> None: + return torch._C._set_dtensor_allow_implicit_replication(value) + + def _propagate_op_sharding_dispatch_slow_path( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + op_info: OpInfo, + # The logic here is a bit messy. There are several reasons why the + # C++ fastpath may have bailed out. If we just cache missed, we will + # come here because we need to actually calculate the real thing. + # There's no need to have a SECOND Python cache lookup; the C++ native + # cache completely subsumes it. But sometimes, we will have failed + # to compute the cache key in C++ entirely. In this case, we DO need + # to do a cache lookup in Python, as the missing cache key in C++ + # means we don't have access to it all. Furthermore, without duping + # this function, we need to do the try_cache test inside of the + # try-except block so that either case hits the inference mode / + # exception rewrapping case. + # + # This should be cleaned up. First, ensuring the C++ codepath can + # always compute a key will be a big help. Second, we should properly + # fastpath inference mode composite implicit autograd so that you + # don't have to throw an exception even in "fastpath". + try_cache: bool, + ) -> object: + try: + # We have basically inlined propagate() here, but WITHOUT the + # output_sharding assignment + if try_cache and not _are_we_tracing(): + return self.sharding_propagator.propagate_op_sharding(op_info.schema) + else: + return self.sharding_propagator.propagate_op_sharding_non_cached( + op_info.schema + ) + except NotImplementedError: + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + # When running under inference mode, CompositeImplicitAutograd ops show up in __torch_dispatch__, + # so we manually decompose them, here + out = op_call.decompose(*args, **kwargs) + assert out is not NotImplemented + return out + else: + raise + except Exception as e: + raise RuntimeError( + f"{e}\n\nSharding propagation failed for {op_info.schema}" + ) from e + + def _dispatch_get_local_results_slow_path( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + op_info: OpInfo, + ) -> object: + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + mesh = op_info.compute_mesh + participating = mesh.get_coordinate() is not None + local_results = None + if participating: + # computation that happens in the current rank of the mesh, normal case + if output_sharding.needs_redistribute: + # If sharding propagation decision needs redistribute, perform redistribute + # on args first, which could potentially modify args (i.e. allgather certain arg) + assert output_sharding.redistribute_schema is not None + self.redistribute_local_args( + op_info, + output_sharding.redistribute_schema, + output_sharding.use_val_from_redistribute_schema, + ) + + local_tensor_args = ( + pytree.tree_unflatten( + cast(list[object], op_info.local_args), + # pyrefly: ignore [bad-argument-type] + op_info.args_tree_spec, + ) + if op_info.args_tree_spec + else op_info.local_args + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + if op_call in self._random_ops: + if not random._rng_tracker and is_rng_supported_mesh(mesh): + # Default to `OffsetBasedRNGTracker` if the parallelism API + # did not already construct one + random._rng_tracker = random.OffsetBasedRNGTracker(mesh) + + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), + ) + + # If the user provided a generator, we hook it up to our RNG manager, but we also pop it from kwargs + # so the op_call does not directly use it (we want op_call to fall back to the 'default' which is + # our RNG manager) + maybe_user_generator = op_info.local_kwargs.pop("generator", None) + assert maybe_user_generator is None or isinstance( + maybe_user_generator, torch.Generator + ) + # maybe_user_generator = None + rng_context = ( + random._rng_tracker._distribute_region( + first_arg._spec, generator=maybe_user_generator + ) + if random._rng_tracker and not first_local_arg.is_meta + else contextlib.nullcontext() + ) + # For DTensor random operator, run it within a RNGTracker context to + # ensure the random number generator is properly distributed. + with rng_context: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + else: + # normal case, run local sharded op computation + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + else: + # For a non-participating device (happens on rank that does not belong to + # the device mesh), we do: + # 1. if the return type is scalar, set the local result to None. + # 2. if the return type is Tensor or List[Tensor], return empty + # tensor(s) with correct dtype. + spec = output_sharding.output_spec + ret_list = op_call._schema.returns + + if spec is None: + # For a scalar return type, the non-participating device has None + # as its local result + local_results = None + else: + + def default_tensor(spec: DTensorSpec) -> torch.Tensor: + if spec.tensor_meta is not None: + shape = spec.tensor_meta.shape + dtype = spec.tensor_meta.dtype + if len(shape) == 0: + # scalar tensor + return torch.zeros((), dtype=dtype) + else: + # non-scalar tensor + return torch.tensor([], dtype=dtype) + else: + raise RuntimeError(f"{spec} has no tensor metadata.") + + if isinstance(spec, DTensorSpec): + # return a Tensor value + local_results = default_tensor(spec) + elif isinstance(spec, Sequence): + # return a List[Tensor] value + local_results = [ + default_tensor(s) if s is not None else None for s in spec + ] + assert isinstance(local_results, list) + if None in local_results: + ret_type = str(ret_list[0].type) + raise NotImplementedError( + f"return type {ret_type} in DTensor op is not supported" + ) + return local_results + + def _dispatch_fast_path_python_tail( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + compute_mesh: DeviceMesh, + output_sharding: OutputSharding, + local_results: object, + participating: bool, + is_inplace_op: bool, + is_out_variant_op: bool, + ) -> object: + """ + Tail of main dispatching logic, called from C++ fast path. + """ + + if output_sharding.output_spec is None: + if op_call == aten.equal.default: + # The output of the equal op is a bool, by converting it into a + # a single value tensor, we can use all-reduce with min reduce op + # to simulate logical and. + assert local_results is None or isinstance(local_results, bool) + r = torch.tensor( + int(local_results) if local_results is not None else 1, + device=compute_mesh.device_type, + ) + dist.all_reduce(r, op=dist.ReduceOp.MIN) + local_results = bool(r.item()) + + if is_inplace_op: + # inplace op should return self instead of re-wrapping + if output_sharding.output_spec is not None: + output_spec = output_sharding.output_spec + assert isinstance(output_spec, DTensorSpec) + assert isinstance(args[0], dtensor.DTensor) + + # NOTE: aten.squeeze_.dim is an inplace op but it also may change + # the inplace argument's tensor meta. Here we choose to special case + # this op because as far as I know this is the only inplace op that + # has such as behavior. We can extend this special case if necessary. + if op_call == aten.squeeze_.dim: + # update the spec to handle tensor meta changes + args[0]._spec = output_spec + # use return_and_correct_aliasing to match the outer and the inner + # aliasing. See https://github.com/pytorch/pytorch/pull/158954 + return return_and_correct_aliasing(op_call, args, kwargs, args[0]) + else: + # For all other inplace ops, check if placement changes are required + # Inplace operations that change placement are not supported because + # they would require redistribution, which breaks aliasing semantics. + # If there are views into the tensor, the views would not be updated. + if args[0]._spec.placements != output_spec.placements: + raise RuntimeError( + f"{op_call}: in-place operations that require placement changes " + f"are not supported. The operation would change placement from " + f"{args[0]._spec.placements} to {output_spec.placements}, " + f"which requires redistribution and breaks aliasing semantics. " + f"Please use the out-of-place version of this operation instead." + ) + # Most inplace ops don't change tensor meta, so no spec update needed + return args[0] + else: + return None + elif is_out_variant_op: + # out variant could possibly have multiple out args (i.e. lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + spec_idx = 0 + for argument in op_call._schema.arguments: + if argument.is_out: + out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + out_dts.append(out_dt) + spec_idx += 1 + + assert len(out_dts) >= 1, "out variant should have at least one out arg" + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + assert op_call == aten.equal.default, op_call + ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + if participating and op_call._schema._is_view_op(): + return return_and_correct_aliasing(op_call, args, kwargs, ret) + else: + return ret + + @staticmethod + def redistribute_local_args( + op_info: OpInfo, + suggested_input_schema: OpSchema, + use_val_from_redistribute_schema: bool, + ) -> None: + debug_mode = get_active_debug_mode() + + # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it + if op_info.args_tree_spec is not None: + flatten_args_schema_to_reshard = tuple( + pytree.tree_leaves(suggested_input_schema.args_schema) + ) + else: + flatten_args_schema_to_reshard = suggested_input_schema.args_schema + + new_local_args: list[object] = [] + for i, arg_spec in enumerate(op_info.flat_args_schema): + reshard_arg_spec = flatten_args_schema_to_reshard[i] + if isinstance(arg_spec, DTensorSpec): + local_tensor = cast(torch.Tensor, op_info.local_args[i]) + if arg_spec != reshard_arg_spec: + redistribute_context = ( + debug_mode.record_redistribute_calls( # type: ignore[union-attr] + i, arg_spec, reshard_arg_spec + ) + if debug_mode is not None + else contextlib.nullcontext() + ) + ExplicitRedistributionContext.observe_redistribution( + arg_spec, + # pyrefly: ignore [bad-argument-type] + reshard_arg_spec, + message=f"Implicit redistribution occurred for {op_info.schema} " + "while ExplicitRedistributionContext was active", + ) + with redistribute_context: + resharded_local_tensor = redistribute_local_tensor( + local_tensor, + arg_spec, + # pyrefly: ignore [bad-argument-type] + reshard_arg_spec, + ) + new_local_args.append(resharded_local_tensor) + else: + new_local_args.append(local_tensor) + else: + if use_val_from_redistribute_schema: + # args can be updated for view related ops, we refer to the + # update in redistribute_schema. + new_local_args.append(reshard_arg_spec) + else: + new_local_args.append(arg_spec) + + op_info.local_args = tuple(new_local_args) + + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> OpInfo: + return self._unwrap_to_op_info_impl(op_call, args, kwargs, True) + + def _unwrap_to_op_info_impl( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + create_schema: bool, + ) -> OpInfo: + # get runtime schema info to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when op says necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: list[object] = [] + kwargs_schema: dict[str, object] = {} + local_args: list[object] = [] + local_kwargs: dict[str, object] = {} + compute_mesh: DeviceMesh | None = None + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + local_args.append(arg._local_tensor) + args_schema.append(arg._spec) + if compute_mesh is None: + # record the first compute device mesh from args + compute_mesh = arg.device_mesh + elif isinstance(arg, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + args_schema.append( + self._try_replicate_spec_for_scalar_tensor( + op_call, arg, compute_mesh + ) + ) + local_args.append(arg) + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + args_schema.append(arg) + local_args.append(arg) + + for k, v in kwargs.items(): + if isinstance(v, dtensor.DTensor): + local_kwargs[k] = v._local_tensor + kwargs_schema[k] = v._spec + elif isinstance(v, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( + op_call, + v, + # pyrefly: ignore [bad-argument-type] + compute_mesh, + ) + local_kwargs[k] = v + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + kwargs_schema[k] = v + local_kwargs[k] = v + + assert compute_mesh is not None, ( + f"found no DeviceMesh from dtensor args for {op_call}!" + ) + op_info = OpInfo( + compute_mesh, + OpSchema( + op_call, + ( + # pyrefly: ignore [bad-argument-type] + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema) + ), + kwargs_schema, + schema_info=runtime_schema_info, + ) + if create_schema + else None, # type: ignore[arg-type] + args_schema, + tuple(local_args), + local_kwargs, + args_spec, + ) + return op_info + + @staticmethod + def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + if spec is not None: + assert isinstance(spec, DTensorSpec), ( + f"output spec does not match with output! Expected DTensorSpec, got {spec}." + ) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + else: + # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor + assert res.ndim == 0, "output tensor should be scalar!" + return res + elif isinstance(res, (list, tuple)): + assert spec is not None and isinstance(spec, (list, tuple)), ( + f"output spec does not match with output! Expected list/tuple, got {spec}." + ) + res_list = [] + for e, s in zip(res, spec): + res_list.append(OpDispatcher.wrap(e, s)) + + return tuple(res_list) if isinstance(res, tuple) else res_list + else: + # if the res contains only non tensor values (i.e. int/float/none), we simply return it + # without rewrapping to DTensor. + return res + + def _try_replicate_spec_for_scalar_tensor( + self, + op_call: torch._ops.OpOverload, + tensor_arg: torch.Tensor, + compute_mesh: DeviceMesh, + ) -> DTensorSpec: + # util function to produce a replicate spec for a scalar tensor arg/kwarg + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed environment.", + stacklevel=2, + ) + + if tensor_arg.numel() == 1 or self._allow_implicit_replication: + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + compute_mesh, + (Replicate(),) * compute_mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + else: + raise RuntimeError( + f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" + " torch.Tensor to DTensor before calling distributed operators!" + " Please see https://docs.pytorch.org/docs/main/distributed.tensor.html#mixed-tensor-and-dtensor-operations" + " for more details." + ) + return replication_spec diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..629bf104e11632beee286256d6f0f77609a289eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,710 @@ +import itertools +import math +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast, NamedTuple, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + _StridedShard, + MaskPartial, + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._debug_mode import _stringify_shape +from torch.utils._dtype_abbrs import dtype_abbrs + + +class ShardOrderEntry(NamedTuple): + """ + Represents how a single tensor dimension is sharded across mesh dimensions. + + Attributes: + tensor_dim: The tensor dimension being sharded (e.g., 0, 1, 2 for a 3D tensor). + mesh_dims: Tuple of mesh dimensions across which this tensor dimension is sharded, + in execution order. The first mesh dim is applied first, second is applied + second, etc. This tuple is guaranteed to be non-empty. + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DISTRIBUTED) + >>> # Tensor dim 1 sharded across mesh dim 2, then mesh dim 0 + >>> ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 0)) + + >>> # Tensor dim 0 sharded only on mesh dim 1 + >>> ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)) + """ + + tensor_dim: int + mesh_dims: tuple[int, ...] # guaranteed to be non-empty + + +# Type alias for the complete shard order specification +# A tuple of ShardOrderEntry, one per sharded tensor dimension +# +# Example: +# shard_order = ( +# ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)), +# ShardOrderEntry(tensor_dim=2, mesh_dims=(0, 3)), +# ) +# This means: +# - Tensor dimension 0 is sharded on mesh dimension 1 +# - Tensor dimension 2 is sharded on mesh dimension 0 first, then mesh dimension 3 +ShardOrder = tuple[ShardOrderEntry, ...] + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: TensorMeta | None = None + + # When a tensor dimension is sharded across multiple mesh axes, + # `shard_order` specifies the sequence in which these shardings are applied. + # This order determines how tensor shards are mapped and distributed across + # devices. + # + # Example: + # For a tensor of shape [8, 16] and a 3D device mesh, if dim 0 is sharded over + # mesh dim 1, and dim 1 is sharded over mesh dim 0 and then mesh dim 2, + # the shard_order would be: + # shard_order = ( + # ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)), + # ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)), + # ) + shard_order: ShardOrder = None # type: ignore[assignment] + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + if self.shard_order is None: + # pyrefly: ignore [bad-assignment] + + _, self.shard_order = self._normalize_placements_into_shard_order( + self.placements, self.mesh + ) + self._hash: int | None = None + + @staticmethod + def _normalize_placements_into_shard_order( + placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> tuple[tuple[Placement, ...], Optional[ShardOrder]]: + # If the returned shard_order is None, it means the StridedShard/Shard + # combinations can't be interpreted as shard order. + # If no _StridedShard in placements, we create default order. + if not any(isinstance(p, _StridedShard) for p in placements): + return placements, DTensorSpec.compute_default_shard_order(placements) + # _StridedShard in placements, try check if it can be decoded as shard order + shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order( + placements, mesh + ) + if shard_order is not None: + normalized_placements = tuple( + [ + p if not isinstance(p, _StridedShard) else Shard(p.dim) + for p in placements + ] + ) + return normalized_placements, shard_order + # unable to decode placements to shard order(e.g., the _StridedShard is + # also used by `view` op shard propagation). + return placements, None + + @staticmethod + def compute_default_shard_order( + placements: tuple[Placement, ...], + ) -> ShardOrder: + """ + Compute the default shard order from placements. + + Returns a ShardOrder where each ShardOrderEntry maps a tensor dimension + to the mesh dimensions it's sharded on, in left-to-right order. + """ + # follow default left-to-right device order if shard_order is not specified + tensor_dim_to_mesh_dims: defaultdict[int, list[int]] = defaultdict(list) + mesh_ndim = len(placements) + for mesh_dim in range(mesh_ndim): + # shard_order doesn't work with _StridedShard + if isinstance(placements[mesh_dim], _StridedShard): + return () + if isinstance(placements[mesh_dim], Shard): + placement = cast(Shard, placements[mesh_dim]) + shard_dim = placement.dim + assert shard_dim >= 0, ( + f"Shard dim {shard_dim} in placements {placements} must be normalized" + ) + tensor_dim_to_mesh_dims[shard_dim].append(mesh_dim) + + # Convert dict into ShardOrderEntry tuples + default_shard_order = tuple( + ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value)) + for key, value in sorted(tensor_dim_to_mesh_dims.items()) + if value + ) + return default_shard_order + + @staticmethod + def _convert_shard_order_to_StridedShard( + shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> tuple[Placement, ...]: + """ + Convert ShardOrder to placements with _StridedShard. + + This function converts a ShardOrder specification into a tuple of Placement objects, + using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions + in a non-default order. The split_factor of each _StridedShard is determined by the + product of mesh dimension sizes that appear earlier in the shard order but later in + the placement tuple. + + Args: + shard_order: ShardOrder specification indicating which tensor dimensions are + sharded on which mesh dimensions and in what execution order. + placements: Tuple of Placement objects that does not contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + Updated tuple of Placement objects with Shard or _StridedShard placements. + + Algorithm: + For each ShardOrderEntry in shard_order: + - For each mesh dimension in the entry's mesh_dims (in order): + - Calculate split_factor as the product of mesh sizes for all mesh dimensions + that appear: + 1. Earlier in the shard order (lower index in mesh_dims), and + 2. Later in the placement tuple (higher mesh dimension index) + - If split_factor == 1: use normal Shard + - Otherwise: use _StridedShard with the calculated split_factor + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),) + >>> placements = (Shard(0), Shard(0), Shard(0)) + >>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1 + >>> # -> placements[2] = Shard(0) + >>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[0] = _StridedShard(0, split_factor=2) + >>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[1] = _StridedShard(0, split_factor=2) + >>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + """ + placements_list = list(placements) + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + for idx in range(len(mesh_dims)): + # TODO(zpcore): split_factor from `view` and `shard order` + # should be able to be multiplied into one. Need to loosen the + # condition here. + mesh_dim = mesh_dims[idx] + if type(placements[mesh_dim]) is not Shard: + raise ValueError( + f"Only Shard placement can be converted to _StridedShard, " + f"found {placements[mesh_dim]} in {placements=}." + ) + split_factor = math.prod( + mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim + ) + if split_factor == 1: + # use normal Shard + placements_list[mesh_dim] = Shard(tensor_dim) + else: + placements_list[mesh_dim] = _StridedShard( + tensor_dim, split_factor=split_factor + ) + return tuple(placements_list) + + @staticmethod + def _maybe_convert_StridedShard_to_shard_order( + placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> ShardOrder | None: + """ + Try to convert _StridedShard placements to ShardOrder. + + This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard + order by examining the split_factor of each _StridedShard and determining its position + in the execution order. If the _StridedShard configuration cannot be represented as a + valid ShardOrder (i.e., there's no shard order that produces the observed split_factors), + this function returns None. + + Args: + placements: Tuple of Placement objects that may contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + ShardOrder if conversion is possible, None otherwise. For placements without + _StridedShard, returns the default shard order. + + Algorithm: + 1. If no _StridedShard in placements, return default shard order + 2. Create an empty list for each tensor dimension to represent mesh dim ordering + 3. Iterate through placements in reverse order (right to left): + - For each Shard/_StridedShard on a tensor dimension: + - Extract its split_factor (1 for Shard, split_factor for _StridedShard) + - Find the position in mesh_dims_order where accumulated_sf equals split_factor + - accumulated_sf is the product of mesh sizes of mesh dimensions that appear + earlier in mesh_dims_order (lower indices) + - Insert mesh_dim at the found position + 4. If no valid position found for any split_factor, return None (unable to convert) + 5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + >>> # Process tensor_dim=0 from right to left: + >>> # - mesh_dim=2: Shard(0) with sf=1 + >>> # Try position 0: accumulated_sf=1, matches! Insert at position 0 + >>> # Current mesh_dims_order order: [2] + >>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Current mesh_dims_order order: [2, 1] + >>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Final mesh_dims_order order: [2, 0, 1] + >>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)) + >>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1 + + Note: + This function validates that _StridedShard can be represented as a ShardOrder. + Not all _StridedShard configurations are valid - the split_factor must match + the product of mesh sizes in some execution order. + """ + if not any(isinstance(p, _StridedShard) for p in placements): + return DTensorSpec.compute_default_shard_order(placements) + max_tensor_dim = ( + max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1 + ) + shard_order = [] + + tensor_dim_to_mesh_dims_order: list[list[int]] = [ + [] for i in range(max_tensor_dim) + ] + for mesh_dim in reversed(range(len(placements))): + cur_placement = placements[mesh_dim] + # _StridedShard may not be a subclass of Shard in the future, so write in this way: + if isinstance(cur_placement, Shard | _StridedShard): + tensor_dim = cur_placement.dim + mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim] + cur_sf = 1 + if isinstance(cur_placement, _StridedShard): + cur_sf = cur_placement.split_factor + accumulated_sf = 1 + find_order = False + for i in range(len(mesh_dims_order) + 1): + if accumulated_sf == cur_sf: + mesh_dims_order.insert(i, mesh_dim) + find_order = True + break + if i < len(mesh_dims_order): + accumulated_sf *= mesh.size(mesh_dims_order[i]) + if not find_order: + # _StridedShard is not convertible to ShardOrder + return None + else: + if not isinstance(cur_placement, Replicate | Partial | MaskPartial): + raise ValueError( + f"Unsupported placement type {type(cur_placement)} encountered in " + f"{placements}; expected Replicate, Partial, or MaskPartial." + ) + for tensor_dim in range(max_tensor_dim): + if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0: + shard_order.append( + ShardOrderEntry( + tensor_dim=tensor_dim, + mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]), + ) + ) + return tuple(shard_order) + + def _verify_shard_order(self, shard_order: ShardOrder) -> None: + """Verify that the shard_order is valid and matches the placements.""" + total_shard = 0 + if any(isinstance(p, _StridedShard) for p in self.placements): + return + prev_tensor_dim = -1 + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + assert len(mesh_dims) > 0, f"shard_order {shard_order} has empty mesh dim" + assert tensor_dim >= 0, ( + f"shard_order {shard_order} has invalid tensor dim {tensor_dim}" + ) + assert tensor_dim > prev_tensor_dim, ( + "tensor dim should be sorted in shard_order" + ) + prev_tensor_dim = tensor_dim + total_shard += len(mesh_dims) + for mesh_dim in mesh_dims: + assert 0 <= mesh_dim < len(self.placements), ( + f"shard_order {shard_order} has invalid mesh dim {mesh_dims}" + ) + assert self.placements[mesh_dim] == Shard(tensor_dim), ( + f"placement[{mesh_dim}] doesn't have a matching shard in shard_order" + ) + assert total_shard == sum(1 for p in self.placements if isinstance(p, Shard)) + + def __setattr__(self, attr: str, value: Any) -> None: + if attr == "shard_order" and value is not None: + self._verify_shard_order(value) + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh`, `placements` or `shard_order` + # to change) + if hasattr(self, "_hash") and attr in ( + "mesh", + "placements", + "tensor_meta", + "shard_order", + ): + self._hash = None + # This assert was triggered by buggy handling for dict outputs in some + # FX passes, where you accidentally iterate over a dict and try to put + # keys into TensorMeta. See https://github.com/pytorch/pytorch/issues/157919 + if attr == "tensor_meta" and value is not None: + from torch.fx.passes.shape_prop import TensorMetadata + + # TODO: the TensorMetadata arises from + # test/distributed/tensor/experimental/test_tp_transform.py::TensorParallelTest::test_tp_transform_e2e + # but I actually can't reproduce it, maybe it is also a bug! + assert isinstance(value, TensorMeta | TensorMetadata), value + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.shard_order, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements, self.shard_order)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def _check_equals(self, other: object, skip_shapes: bool = False) -> bool: + if not ( + isinstance(other, DTensorSpec) + and self.mesh == other.mesh + and self.placements == other.placements + and self.shard_order == other.shard_order + ): + return False + if self.tensor_meta is None or other.tensor_meta is None: + return self.tensor_meta == other.tensor_meta + + if skip_shapes: + return self.tensor_meta.dtype == other.tensor_meta.dtype + return ( + self.tensor_meta.shape == other.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == other.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == other.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __eq__(self, other: object, /) -> bool: + return self._check_equals(other) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + placement_str = self.format_shard_order_str(self.placements, self.shard_order) + if self.tensor_meta is not None: + tensor_shape = _stringify_shape(self.tensor_meta.shape) + tensor_dtype = dtype_abbrs[self.tensor_meta.dtype] + else: + tensor_shape = "unknown shape" + tensor_dtype = "unknown dtype" + + return f"Spec({tensor_dtype}{tensor_shape}({placement_str}))" + + @staticmethod + def is_default_device_order(shard_order: ShardOrder) -> bool: + """ + Check if the device order is the default left-to-right order. + """ + for entry in shard_order: + mesh_dims = entry.mesh_dims + is_increasing = all( + prev < nxt for prev, nxt in itertools.pairwise(mesh_dims) + ) + if not is_increasing: + return False + return True + + @staticmethod + def format_shard_order_str( + placements: tuple[Placement, ...], + shard_order: ShardOrder | None = None, + ) -> str: + """ + Format DTensor sharding information as a human-readable string. + + This method formats the sharding pattern in mesh-centric order, showing the placement + for each mesh dimension sequentially. When a tensor dimension is sharded across multiple + mesh dimensions, the order index indicates the execution sequence of the sharding operations. + + Args: + placements: Tuple of placement objects for each mesh dimension. + shard_order: Optional ShardOrder specifying the sharding order. + + Returns: + String representation of the sharding pattern in mesh-centric format. + + Example: + For a 3D tensor on a 2x2x2x2 mesh (16 devices) with:: + + placements = [Partial(), Shard(1), Shard(1), Replicate()] + shard_order = (ShardOrderEntry(tensor_dim=1, mesh_dims=(2, 1)),) + + Mesh configuration: + - mesh_dim_0: Partial reduction (sum) + - mesh_dim_1: Shard tensor dimension 1 (executed second, order index 1) + - mesh_dim_2: Shard tensor dimension 1 (executed first, order index 0) + - mesh_dim_3: Replicate + + Output: ``"PS(1)[1]S(1)[0]R"`` + + Explanation: + - ``P``: mesh dimension 0 has partial reduction + - ``S(1)[1]``: mesh dimension 1 shards tensor dimension 1 (order index 1 means second) + - ``S(1)[0]``: mesh dimension 2 shards tensor dimension 1 (order index 0 means first) + - ``R``: mesh dimension 3 replicates + + The format follows mesh dimension order (0, 1, 2, 3), and when a tensor dimension + is sharded across multiple mesh dimensions, the bracketed index shows the execution + order: ``[0]`` is executed first, ``[1]`` is executed second, etc. + """ + out_str = "" + # native dtensor-style sharding representation: map from mesh + # dim to tensor dim + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + if shard_order is not None: + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + + if placement.dim == tensor_dim: + assert mesh_dim in mesh_dims + if len(mesh_dims) > 1: + out_str += f"{placement}[{mesh_dims.index(mesh_dim)}]" + else: + # no need to show device order if the tensor dim is + # only sharded in one mesh dim + out_str += str(placement) + break + else: + out_str += str(placement) + else: + out_str += str(placement) + return out_str + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def num_shards_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. Unlike `dim_map`, `num_shards_map` + denotes how many shards each tensor dim has. Like `dim_map`: + len(num_shards_map) == dist_tensor.ndim + num_shards_map[i] = 1: means tensor dim i is not sharded + num_shards_map[i] = j: means tensor dim i has j shards in total + + For example, we have a dist tensor of shape [18, 20, 30], + a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements + ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor + would be: [4, 2, 1]. + """ + r = [1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + r[shard_dim] *= self.mesh.size(i) + + return r + + @property + def sums(self) -> list[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: list[int], + sums: list[int], + tensor_meta: TensorMeta | None = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: list[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self) -> bool: + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_sharded(self) -> bool: + """ + return True if the current DTensorSpec uses Shard() placement on any mesh dims (devices) + """ + return any(placement.is_shard() for placement in self.placements) + + def shallow_copy_with_tensor_meta( + self, tensor_meta: TensorMeta | None + ) -> "DTensorSpec": + """ + Shallow copy the DTensorSpec with a new tensor_meta. + """ + assert tensor_meta is not None, "shallow copy with no tensor_meta!" + return DTensorSpec( + self.mesh, + self.placements, + tensor_meta=tensor_meta, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..4fec0293554ac1c0bb4031b91953386d3dc6d541 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_op_schema.py @@ -0,0 +1,612 @@ +# mypy: allow-untyped-defs +""" +DTensor operator schema definitions and utilities. + +This module defines the core data structures and utilities for describing and managing +distributed tensor operations in PyTorch's DTensor system. It provides the foundational +schema types used for sharding propagation, operator strategy selection, and distributed +execution planning. + +Key components: +- OpSpec: Describes acceptable sharding placements for operations +- OpStrategy: Represents the possible sharding strategies for an operator +- TupleStrategy: Container for multiple strategies when ops have tuple/list of tensors input +- OpSchema: Describes operator input/output schemas with DTensorSpecs +- OutputSharding: Manages output sharding specifications and redistribution +- RuntimeSchemaInfo: Runtime execution metadata for operators +- OpInfo: Complete runtime operator execution information + +These schema definitions enable the DTensor system to: +1. Propagate tensor sharding information to the operator outputs +2. Greedily select sharding strategies for distributed operations +3. Plan and execute tensor redistributions when needed +4. Cache sharding decisions for performance optimization +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cached_property +from typing import Any +from typing_extensions import deprecated + +import torch +from torch._C import ( + _DTensor_OpSchema_post_init, + _DTensor_OpSchema_recompute_comparison_key, +) +from torch._ops import OpOverload +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch.utils._cxx_pytree import ( + register_pytree_node, + tree_leaves, + tree_map_only, + TreeSpec, + ) +except ImportError: + from torch.utils._pytree import ( # type: ignore[no-redef, assignment] + register_pytree_node, + tree_leaves, + tree_map_only, + TreeSpec, + ) + + +# Common type aliases +ArgsType = tuple[object, ...] +KwargsType = dict[str, object] + +PlacementList = list[Placement | None] + +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should +# be the same set of possibilities. +OutputSpecType = DTensorSpec | Sequence[DTensorSpec | None] | None + + +def _rebuild_tensor_from_dtensor_meta(arg) -> object: + """ + This is used to propagate tensor metadata, must be under fake mode + """ + assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta." + return torch.empty_strided( + arg.tensor_meta.shape, + arg.tensor_meta.stride, + dtype=arg.tensor_meta.dtype, + ) + + +def _pretty_print_spec(spec: object) -> str: + if spec is None: + return "None" + elif isinstance(spec, DTensorSpec): + return "".join([str(p) for p in spec.placements]) + elif isinstance(spec, Sequence): + return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")" + else: + raise RuntimeError(f"Unknown spec type to print: spec={spec}") + + +@dataclass +class OpSpec: + """ + An OpSpec describes an acceptable sharding placements of an operation, with the + specified DTensorSpecs for both the output and the inputs. + + note: when the op return value is a single DTensor object, output_specs is + DTensorSpec; when the return value is a tuple of Optional[DTensor], + output_specs is a tuple of Optional[DTensorSpec]. + + note: we MUST produce an DTensorSpec for every output that is a Tensor. None + entries only occur for non-Tensor outputs (e.g., operators that return Optional[Tensor], + or non-Tensor outputs.) + + invariant: the DeviceMesh on all DTensorSpec must be the same + """ + + # output_specs and input_specs are related: for this op, given these input_specs, + # this is the way the output would look + output_specs: DTensorSpec | tuple[DTensorSpec | None, ...] + input_specs: Sequence[DTensorSpec] | None = None + + """ + redistribute_cost tells how expensive it is to redistribute a given input into the + placement specified in this OpSpec. + + outer list: one entry (list) per (tensor) input in the op's arg schema + inner list: one entry (cost value) per possible sharding spec for that input + + Example: + ------- + another_op() -> tensor_a # another_op produces the output that becomes our first input + my_op(tensor_a) + + Let's assume this OpSpec's input_specs are [Replicate()], + but another_op() supports 2 strategies (OpSpecs) which produce outputs of + Replicate() + Shard(0) + + In this example, redistribute_costs would look like this + [ + # one row representing "my_op's first input" (tensor_a) + [ + # two entries, one for each strategies supported by another_op + 0.0, # cost of redistributing tensor_a from 'Replicate()' + K, # cost of redistributing tensor_a from 'Shard(0)' + ], + """ + redistribute_cost: list[list[float]] | None = None + + @cached_property + def output_spec(self) -> DTensorSpec: + """ + This function requires that the strategy have exactly one DTensorSpec as the + output spec. If the output_specs is a tuple, we throw an exception. + """ + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec but got: {self.output_specs}" + ) + + @cached_property + def mesh(self): + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs.mesh + elif isinstance(self.output_specs, tuple): + out_spec = self.output_specs[0] + assert isinstance(out_spec, DTensorSpec) + return out_spec.mesh + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec or a tuple of DTensorSpec but got: {self.output_specs}" + ) + + def input_spec(self, index: int = 0) -> DTensorSpec: + assert self.input_specs is not None, "input_specs of OpSpec is None!" + assert len(self.input_specs) > index, ( + f"Invalid index {index} for input_specs of length " + f"{len(self.input_specs)}: {self.input_specs}" + ) + return self.input_specs[index] + + def __str__(self) -> str: + if self.input_specs is not None: + input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> " + else: + input_specs_str = "" + output_spec_str = _pretty_print_spec(self.output_specs) + return f"{input_specs_str}{output_spec_str}" + + +class StrategyType: + """ + Base class type for op strategy, We have two StrategyType: + OpStrategy and TupleStrategy + """ + + +class OpStrategy(StrategyType): + """ + OpStrategy that consists of a list of sharding strategies associated with the op, + where each strategy is an OpSpec that describes the acceptable input/output sharding. + + invariant: the DeviceMesh on all OpSpec must be the same + """ + + def __init__(self, strategies: list[OpSpec]) -> None: + super().__init__() + self.strategies: list[OpSpec] = strategies + + def __str__(self) -> str: + strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) + mesh_shape = self.mesh_shape + return f"OpStrategy[{strategy_list_str}] @ mesh: {mesh_shape}" + + def max_num_shards(self) -> int: + """ + Returns the max number of shards across all OpSpecs + """ + return max(strategy.output_spec.num_shards for strategy in self.strategies) + + @property + def mesh(self): + return self.strategies[0].mesh + + @property + def mesh_shape(self): + return self.strategies[0].mesh.shape + + @property + def ndim(self): + return self.strategies[0].output_spec.ndim + + @property + def shape(self): + return self.strategies[0].output_spec.shape + + +class TupleStrategy(StrategyType): + """ + TupleStrategy is a special case for operators that are fundamentally compound or batched such that some subset + of the inputs and outputs are completely unrelated to some other subset. + + Generally, foreach_* ops are the most common use-case for TupleStrategy, because they accept lists of inputs, + but operate independently on each input or tuple of zipped inputs. + + For example, [out_a, out_b] = torch.foreach_add([a, b], scalar): input a's sharding only affects out_a's sharding, + independent of b and out_b. + + An example of an operator that should NOT use TupleStrategy is torch.split. It produces a List[Tensor] + as its output, but the sharding decision of one output is bound together with the decision + of each other output and the common input. + """ + + def __init__( + self, + children: Sequence[StrategyType], + ) -> None: + super().__init__() + self.children: Sequence[StrategyType] = children + + @property + @deprecated( + "TupleStrategy.childs is deprecated, use TupleStrategy.children instead.", # codespell:ignore childs + category=FutureWarning, + ) + def childs(self) -> Sequence[StrategyType]: # codespell:ignore childs + """ + Alias for children, to maintain backward compatibility. + """ + return self.children + + def child_mesh(self, index: int) -> DeviceMesh: + op_strategy = self.children[index] + assert isinstance(op_strategy, OpStrategy) + return op_strategy.mesh + + def __str__(self) -> str: + child_strategies_str = ", ".join( + [f"{str(strat)}" for idx, strat in enumerate(self.children)] + ) + return f"TupleStrategy({child_strategies_str})" + + +try: + register_pytree_node( + TupleStrategy, + lambda node: (node.children, None), + lambda children, _: TupleStrategy(tuple(children)), + ) +except ValueError: + # already registered TupleStrategy, skip + pass + + +@dataclass +class RuntimeSchemaInfo: + """ + RuntimeSchemaInfo stores the operator schema related information for runtime (eager) + execution. This is mainly used for two ways: 1. to generate hash for args to determine + whether to re-run sharding prop or not 2. to determine if we need pytree + """ + + # This static_argnum records static arg "starting index" for ops that have non-tensor + # args/kwargs which would affect sharding propagation results. All args starting from + # this index would be hashed to our sharding cache. + # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc. + static_argnum: int = 100 + # This static_kwargkey records static kwarg names which would affect sharding prop + static_kwargkey: list[str] | None = None + # each op can decide if it wants to use pytree flatten/unflatten during operator + # eager execution, by default we don't need to do flatten/unflatten, only if the + # op indicate it needs to, this is to accelerate eager performance. + needs_pytree: bool = False + + +@dataclass +class OpSchema: + """ + OpSchema is a data class that describes an operator input schemas, it includes + DTensorSpecs/OpStrategies (instead of DTensor) and non-tensor args/kwargs (positional + order preserved). It is mainly used by the DTensor's dispatching logic to perform various + actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) + + NOTE: this must be used as a read only data class + TODO: make this a frozen dataclass + + Args: + op: the operator overload we are intercepting + args_schema: contains args except that the DTensor args have been replaced + with its DTensorSpec or OpStrategy + kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced + with its DTensorSpec or OpStrategy + """ + + op: OpOverload + args_schema: ArgsType + kwargs_schema: KwargsType + + schema_info: RuntimeSchemaInfo | None = None + + _comparison_key: tuple[object, ...] | None = None + + @property + def args_spec(self) -> tuple[DTensorSpec, ...]: + """ + args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list + with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) + mainly used by sharding propagation to propagate the output spec + """ + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, DTensorSpec)) + + @property + def args_strategy(self) -> tuple[OpStrategy, ...]: + # filter out non-relevant values from args schema to get a clean OpStrategy list + # separate with args_spec for the ease of type annotation + # TODO: see if we should merge this with args_spec + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, OpStrategy)) + + @property + def kwargs_strategy(self) -> tuple[OpStrategy, ...]: + # returns OpStrategy items from kwargs_schema. + kwargs_vals = ( + tree_leaves(self.kwargs_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.kwargs_schema.values() + ) + return tuple(item for item in kwargs_vals if isinstance(item, OpStrategy)) + + def __repr__(self) -> str: + args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) + return ( + f"OpSchema(op={self.op}," + f" args_schema=({args_schema})," + f" kwargs_schema={self.kwargs_schema})" + ) + + def __str__(self) -> str: + args_schema: list[str] = [] + device_mesh = None + + for arg in self.args_schema: + if isinstance(arg, DTensorSpec): + args_schema.append(str(arg)) + device_mesh = arg.mesh + elif isinstance(arg, OpStrategy): + assert len(arg.strategies) == 1 + args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) + device_mesh = arg.mesh + elif isinstance(arg, TupleStrategy): + first_op_strategy = arg.children[0] + assert isinstance(first_op_strategy, OpStrategy) + device_mesh = first_op_strategy.mesh + args_schema.append(str(arg)) + else: + args_schema.append(str(arg)) + + return f"{self.op}({', '.join(args_schema)}) on {device_mesh})" + + def __post_init__(self) -> None: + _DTensor_OpSchema_post_init(self) + + def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool: + is_tensor = isinstance(arg, DTensorSpec) + if is_tensor: + return True + + if not isinstance(arg, list): + return False + + return all(isinstance(e, DTensorSpec) or e is None for e in arg) + + def return_type_tuple_tensor_like(self) -> bool: + # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats + # in the tuple, but the first element must be a Tensor, so this check is enough + return_types = self.op._schema.returns + return len(return_types) > 1 and isinstance( + return_types[0].type, torch.TensorType + ) + + def return_type_list_tensor_like(self) -> bool: + # returns True if the return type is a List + return_types = self.op._schema.returns + return len(return_types) == 1 and isinstance( + return_types[0].type, torch.ListType + ) + + def return_type_tensor(self) -> bool: + return_types = self.op._schema.returns + # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like + # return types, so this check is enough for tensor like types + return isinstance(return_types[0].type, torch.TensorType) + + def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh: + """ + This util can be used to get a mesh from the OpSchema that contains multiple + DTensors as arguments. When `validate` is True, it will try to validate that all the + arguments have the same mesh to avoid unexpected cross mesh errors. + + NOTE: this util currently does not handle TupleStrategy when `validate=True`, + this is because for TupleStrategy there could be different types of checks, i.e.: + - for stack and cat like op, we need to check within a TupleStrategy is every + input is on the same mesh + - for foreach like ops we need to check "zipped" inputs are on the same mesh + for each index. + """ + first_arg = self.args_schema[0] + if isinstance(first_arg, (DTensorSpec, OpStrategy)): + mesh = first_arg.mesh + elif isinstance(first_arg, (list, tuple, TupleStrategy)): + first_elem = ( + first_arg.children[0] + if isinstance(first_arg, TupleStrategy) + else first_arg[0] + ) + assert isinstance(first_elem, (DTensorSpec, OpStrategy)) + mesh = first_elem.mesh + else: + raise ValueError(f"Cannot find device mesh from args for op : {self.op}.") + + if validate: + for arg in self.args_schema[1:]: + if isinstance(arg, (DTensorSpec, OpStrategy)) and arg.mesh != mesh: + raise RuntimeError( + f"DTensor does not support cross-mesh operation on {self.op}! " + f"Got meshes: {mesh} {arg.mesh}. " + f"Please make sure all the arguments have the same DeviceMesh." + ) + + return mesh + + def is_inplace_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an inplace variant, it might not + # be entirely correct, but it's good enough for now. + return self.op._schema.name[-1] == "_" + + def is_out_variant_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an out variant, it might not + # be entirely correct, but it's good enough for now. + return "out" in self.op._schema.overload_name + + def is_view_op(self) -> bool: + return self.op._schema._is_view_op() + + def _recompute_comparison_key(self) -> None: + _DTensor_OpSchema_recompute_comparison_key(self) + + def __hash__(self) -> int: + return hash(self._comparison_key) + + def __eq__(self, other: object) -> bool: + # early return checks + if not isinstance(other, OpSchema): + return False + + if self.op != other.op: + return False + + if len(self.args_schema) != len(other.args_schema): + return False + + return self._comparison_key == other._comparison_key + + def gen_fake_args(self) -> ArgsType: + """ + gen_fake_args: generate fake args for the operator, this is mainly used + by sharding propagation rules to generate fake args for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.args_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def gen_fake_kwargs(self) -> KwargsType: + """ + gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used + by sharding propagation rules to generate fake kwargs for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.kwargs_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: + suggestion_args_spec = self.args_spec + new_arg_schema: list[object] = [] + idx_of_args_spec = 0 + if ( + origin_schema.schema_info is not None + and origin_schema.schema_info.needs_pytree + ): + args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema) + else: + args_schema = origin_schema.args_schema + for arg in args_schema: + if isinstance(arg, DTensorSpec): + new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) + idx_of_args_spec += 1 + else: + new_arg_schema.append(arg) + self.args_schema = tuple(new_arg_schema) + self.kwargs_schema = origin_schema.kwargs_schema + self._recompute_comparison_key() + + +@dataclass +class OutputSharding: + """ + OutputSharding is a data class that is used by the sharding propagation, + it could set the output_spec upon successful propagation. If needs_redistribute + is set to True, a redistribute_schema would be returned together to indicate + the input arguments needs to be redistributed before the op execution. + + NOTE: the redistribute_schema generated by sharding propagation should be + exactly the same as the operator OpSchema, except the DTensorSpecs + """ + + # specifies the output sharding pattern + output_spec: OutputSpecType + # schema for redistribution if needed + redistribute_schema: OpSchema | None = None + # flag indicating if inputs need redistribution + needs_redistribute: bool = False + # flag to use values from `redistribute_schema` + use_val_from_redistribute_schema: bool = False + + @cached_property + def mesh(self): + if isinstance(self.output_spec, DTensorSpec): + return self.output_spec.mesh + elif isinstance(self.output_spec, tuple): + out_spec = self.output_spec[0] + if isinstance(out_spec, DTensorSpec): + return out_spec.mesh + else: + raise ValueError(f"Unknown output spec type: {type(out_spec)}") + else: + raise ValueError(f"Unknown output spec type: {type(self.output_spec)}") + + +@dataclass +class OpInfo: + """ + All Runtime Op execution info are packed here + """ + + # The first compute device mesh recorded from args + # NOTE: one op could have multiple meshes from its args. We just record the first + # mesh here to check if current rank should participate in computation or not. + compute_mesh: DeviceMesh + + # compete runtime operator infos + schema: OpSchema + flat_args_schema: list[object] + local_args: Sequence[object] + local_kwargs: dict[str, object] + args_tree_spec: TreeSpec | None = None + + # the output sharding info + output_sharding: OutputSharding | None = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a490b3a5147c1d685adf88511ab0ddf068f567 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a02fc56013f9207d510a83ddf084b0d4ed0804b7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b68ba780563de04930296e104df75f390085de40 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fc1041bf8f08af3952af2208867a37ff14c38bc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fdb48628afcdcea8d80801d15e65dff41b86351 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_mask_buffer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_mask_buffer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5197c3e27cb8b173dc8356f569450386c1f62390 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_mask_buffer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccaa893f3c990bd9dd16b52cd75f692f49c47940 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d8487af807c5fbebeed705f5c8b8e33249f5be Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb795a1c0a64f3f4d1d41cf8cf224ce9858f7be Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..566e0824fef8fb79cbd20d8f5d3d59e4323d4029 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d8dc5b9874bbf2a97bb7dee051a1ed3bff2253e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fc560b137172375e1c13bf7441bf6b205723879 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/registration.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/registration.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e054b6d9e2e593608fa7f7cbb2d4cc7efdefc22a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/registration.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32199f27dfe506d01fcf321fd0a995ee7d637a2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_random.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_random.py new file mode 100644 index 0000000000000000000000000000000000000000..995a057b0c7faa085ae94d67e11d7603d057e05a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_random.py @@ -0,0 +1,478 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import warnings +from logging import getLogger +from typing import Optional + +import torch +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import _StridedShard, Shard + + +logger = getLogger(__name__) + +__all__ = [ + "is_rng_supported_mesh", + "manual_seed", + "OffsetBasedRNGTracker", +] + +_rng_tracker: Optional["_RNGStateTracker"] = None + + +def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool: + """Checks if the current device of ``device_mesh`` supports DTensor's random APIs. + Currently DTensor Random APIs only supports cuda/cuda-like devices. We suggest + users call this API to test the availability before using our random APIs. + + Args: + device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the + random ops APIs are supported. + + Returns: + A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise. + + .. warning:: + Currently we only support correct RNG on cuda/cuda-like devices. + """ + device_handle = _get_device_handle(device_mesh.device_type) + if device_handle and hasattr(device_handle, "set_rng_state"): + return True + else: + # TODO: Logs way too much + warnings.warn( + f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh", + stacklevel=2, + ) + return False + + +def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: + """Sets the seed for generating random numbers for the calling rank. + + Args: + seed (int): The desired seed. + device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is + required that the ``device_mesh`` include the calling rank. This is + to ensure that the SPMD region maintains a synchronous RNG state, which + means no ranks should be initialized with values other than ``seed``. + + Returns: + None + + .. warning:: + :func:`manual_seed` does not check the ``seed`` value correctness. Users must + ensure on their own that the value passed in is the desired ``seed`` for ranks + within ``device_mesh``. + If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it, + ``manual_seed`` will throw an error. + Current implementation only supports a GPU device mesh. + """ + if not is_rng_supported_mesh(device_mesh): + warnings.warn( + "DTensor manual_seed() may not have complete support " + f"on {device_mesh.device_type} device mesh", + stacklevel=2, + ) + return + + # TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently + # bundled together with this API. See torchtitan/distributed/utils.py:set_determinism + # warnings.warn( + # "DTensor manual_seed() is deprecated, since DTensor no longer maintains a separate copy of generator state. " + # "Use `torch.manual_seed` instead" + # ) + # Note: we still need to ensure setting `run_state_sync=False` to support the pp case + + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. + global _rng_tracker + if not _rng_tracker: + _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False) + + if device_mesh.get_coordinate() is None: + raise RuntimeError( + "manual_seed requires the current rank to be a part of the device mesh " + "otherwise DTensor RNG state on the rank will not be initialized and " + "the behavior of DTensor random ops is undefined." + ) + + # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing + # as manual seed on torch. + # + # torch.manual_seed will handle LocalTensor mode correctly by + # iterating through all ranks if seed is a LocalIntNode. + torch.manual_seed(seed) + + +class _PhiloxState: + """ + Convenience accessor for interpreting the packed bits of (seed: uint64, offset: uint64) in the philox state, + which for some reason is actually exposed as a size-16 uint8 tensor. + + The state is always moved to .cpu since it is necessary for it to be on CPU before applying it back to a generator. + """ + + def __init__(self, state: torch.Tensor): + self._state = state.to("cpu") + + @property + def state(self): + return self._state + + @property + def offset(self) -> int: + return int(self._state[8:].view(dtype=torch.int64).item()) + + @offset.setter + def offset(self, offset: int) -> None: + offset_tensor = torch.tensor([offset], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self._state[8:] = offset_tensor + + @property + def seed(self) -> int: + return int(self._state[:8].view(dtype=torch.uint64).item()) + + @seed.setter + def seed(self, seed: int) -> None: + seed_tensor = torch.tensor([seed], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self._state[:8] = seed_tensor + + +class _RNGStateTracker: + """ + _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object) + in a dict, mapping from a corresponding tag to each state tensor. It also provides + a set of convenient utility methods to help access/modify the state tensors. The most + important interface is _distribute_region which will be used when DTensor executes + a random op (an operator that calls RNG). + """ + + def __init__(self, device: torch.device): + # pyrefly: ignore [read-only] + self._device = device + self._device_handle = _get_device_handle(self._device.type) + if not (self._device_handle and self._device_handle.is_available()): + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"{device.type} device but couldn't find." + ) + self._use_distribute_region = True + + @property + def distribute_region_enabled(self) -> bool: + return self._use_distribute_region + + @distribute_region_enabled.setter + def distribute_region_enabled(self, value) -> None: + self._use_distribute_region = value + + def _distribute_region( + self, spec: DTensorSpec, generator: torch.Generator | None = None + ): + pass + + def _manual_seed(self, parallel_seed: int) -> None: + pass + + +class OffsetBasedRNGTracker(_RNGStateTracker): + """ + This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states + should be shared and synchronized among all ranks to respect the semantics of DTensor + random operators. + + note: _RNGStateTracker only supports cuda/cuda-like device. + """ + + def __init__( + self, + device_mesh: DeviceMesh, + run_state_sync: bool = True, + ): + super().__init__(_resolve_device(device_mesh=device_mesh)) + assert self._device_handle is not None + # DTensor RNG tracker so far only supports CUDA/CUDA-like devices + if self._device.type == "cpu": + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead." + ) + + rng_state = self._get_device_state() + if run_state_sync: + # synchronize RNG state using rank 0's current one + torch.distributed.broadcast(rng_state, 0) + my_rng_state = self._get_device_state() + if not all(my_rng_state == rng_state): + logger.warning( + "DTensor is synchronizing RNG states of every rank with the state from rank 0. " + "This behavior is deprecated. " + "Please call `torch.manual_seed()` on every rank that participates in SPMD DTensor Operations with " + "the same seed. If using Pipeline Parallelism, each pipeling state would use a different seed, " + "but all ranks belonging to one pipeline stage would use the same seed." + ) + self._set_device_state(rng_state) + + def _get_device_state(self) -> torch.Tensor: + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + rng_state = self._device_handle.get_rng_state().to(self._device) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + return rng_state + + def _set_device_state(self, state: torch.Tensor): + # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state` + # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug + # for now, we just convert back to cpu here to make sure it always works. + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + self._device_handle.set_rng_state(state.to("cpu")) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + + @contextlib.contextmanager + def _distribute_region( + self, spec: DTensorSpec, generator: torch.Generator | None = None + ): + from torch.distributed._local_tensor import maybe_enable_local_tracker + + if local_tracker_context := maybe_enable_local_tracker( + self._device.type, self.distribute_region_enabled, spec, generator + ): + with local_tracker_context: + yield + return + + # regular (non-LocalTensor) mode + if generator is not None: + # This is a little hacky, but for any user-passed generator, we store its state under a unique key, + # not because we need to keep a copy of it but because its the easiest way to make it work with the + # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region. + state = _PhiloxState(generator.get_state()) + else: + state = _PhiloxState(self._get_device_state()) + + if self.distribute_region_enabled: + if self._device.type == "hpu": + self._device_handle.set_rng_ctx("philox") + old_offset = state.offset + self._set_pre_op_offset(state, spec) + with torch.random.fork_rng( + devices=[self._device], device_type=self._device.type + ): + assert self._device_handle is not None + self._device_handle.set_rng_state(state.state) + try: + yield # execute the region code + finally: + # update offset to synchronize among ranks + self._set_post_op_offset(state, spec, old_offset) + if self._device.type == "hpu": + self._device_handle.unset_rng_ctx("philox") + else: + yield + + if generator is not None: + # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future + # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates + # the seed value in their rng and uses it with DTensor again, we always use the latest value + generator.set_state(state.state) + else: + self._set_device_state(state.state) + + def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: + """Set the starting RNG offset for current device's local shard before actual + op execution. The pre_op_offset value should start from the current RNG offset + and increment by the size of local shard until it reaches the size of the whole + DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset + will be the same. + + Args: + state (:class:`Tensor`): The generator state to modify + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we prepare the offset for running random ops. + + Returns: + None + + .. warning:: + Note that, current implementation does not consider DTensor's continguity. + + Example: + take a DTensor of shape [8, 16] as an example. Assume that the DTensor + is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]), + and the mesh is: + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank + in the mesh. For example, the coordinate of rank 5 is (1, 0, 1). + + Another concept to introduce besides rank coordinate is shard coordinate. + Each rank holds a local shard of the DTensor. In the example, the DTensor + is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and + rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each. + That being said, the local shard on rank 0 and rank 2 correspond to the same + shard of the DTensor. To denote each DTensor shard, we use a shard coordinate + (in the example, it will be a tuple (i, j) where shard (i, j) has the slice + DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2). + + Once we have rank coordinate and shard coordinate, we can calculate on each rank + what shard of the DTensor the rank holds, with the help of dim_map. The dim_map + of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord + (x, y, z) is simply (z, x) by taking(rank_coord[dim_map[0]],rank_coord[dim_map[1]]). + Following this calculation, + rank 0 and rank 2 holds the shard of coord (0, 0); + rank 1 and rank 3 holds the shard of coord (0, 1); + rank 4 and rank 6 holds the shard of coord (1, 0); + rank 5 and rank 7 holds the shard of coord (1, 1); + + The last value to calculate before obtaining the starting offset is the shard linear index. + The starting offset for each rank will be its shard_linear_index * local_tensor_numel. + """ + mesh = spec.mesh + mesh_coordinate = mesh.get_coordinate() + assert mesh_coordinate is not None + + # Compute shard index and total number of shards on each tensor dim + shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( + mesh_coordinate, spec + ) + + # compute shard linear index + shard_linear_idx = self._calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) + + # compute starting offset using the first shard's size + local_size_on_rank_0 = _calc_first_shard_size(spec) + + from torch.distributed.tensor._ops.utils import prod + + local_size = prod(local_size_on_rank_0) + + # get current RNG offset + current_offset = state.offset + + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state.offset = current_offset + offset_incr + + def _set_post_op_offset( + self, state: _PhiloxState, spec: DTensorSpec, old_offset: int + ) -> None: + """Sets the RNG to a synchronized state after running the local random op. Every + rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is + the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor + random ops. + + Args: + state (:class:`Tensor`): The generator state to modify. + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we post-process the offset for running random ops. + + Returns: + None + """ + dtensor_shape = spec.shape + + from torch.distributed.tensor._ops.utils import prod + + numel = prod(dtensor_shape) + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + numel = (numel + 3) // 4 * 4 + state.offset = old_offset + numel + + def _calc_shard_linear_idx( + self, shard_coord: list[int], shard_size: list[int] + ) -> int: + return _calc_shard_linear_idx(shard_coord, shard_size) + + +def _calc_first_shard_size(spec: DTensorSpec) -> list[int]: + local_size_on_rank_0 = list(spec.shape) + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard | _StridedShard): + mesh_dim_size = spec.mesh.size(idx) + shard_dim = placement.dim + local_size_on_rank_0[shard_dim], _ = placement._local_shard_size_and_offset( + spec.shape[shard_dim], + mesh_dim_size, + 0, + ) + return local_size_on_rank_0 + + +def _calc_shard_info( + mesh_coordinate: list[int], spec: DTensorSpec +) -> tuple[list[int], list[int]]: + mesh = spec.mesh + # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP + # case. Replace the custom logic with dim_map once we support it. + dim_map: list[int | list[int]] = [-1] * spec.ndim + for i, placement in enumerate(spec.placements): + if isinstance(placement, Shard | _StridedShard): + shard_dim = placement.dim + if dim_map[shard_dim] == -1: + dim_map[shard_dim] = [i] + else: + mesh_dim_list = dim_map[shard_dim] + assert isinstance(mesh_dim_list, list) + mesh_dim_list.append(i) + + # Compute shard coordinate: + # The coordinate on each tensor dim is a tuple (idx, range) + # If a DTensor is partitioned on its dim i into n shards, and the current rank + # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i + assert mesh_coordinate is not None + mesh_size = mesh.shape + shard_idx_by_dim = [] + total_num_shards_by_dim = [] # total number of shards on each tensor dim + for mesh_dim in dim_map: + shard_idx = 0 + total_num_shards = 1 + # the tensor dim is sharded on more than 1 mesh dim + if isinstance(mesh_dim, list): + rank_coord = [mesh_coordinate[d] for d in mesh_dim] + num_shards = [mesh_size[d] for d in mesh_dim] + # compute the shard idx and total number of shards + for idx, size in zip(rank_coord, num_shards): + shard_idx = shard_idx * size + idx + total_num_shards *= size + + shard_idx_by_dim.append(shard_idx) + total_num_shards_by_dim.append(total_num_shards) + return shard_idx_by_dim, total_num_shards_by_dim + + +def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: + # compute shard linear index + shard_linear_idx = 0 + shard_coord_stride = 1 + for idx, size in zip(reversed(shard_coord), reversed(shard_size)): + shard_linear_idx += idx * shard_coord_stride + shard_coord_stride *= size + + return shard_linear_idx + + +def _resolve_device(device_mesh: DeviceMesh) -> torch.device: + device_type = device_mesh.device_type + device_handle = _get_device_handle(device_type) + assert device_handle is not None + device_idx = device_mesh.get_rank() % device_handle.device_count() + + @maybe_run_for_local_tensor + def get_device(device_idx): + return torch.device(f"{device_type}:{device_idx:d}") + + return get_device(device_idx) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py new file mode 100644 index 0000000000000000000000000000000000000000..7119fd9ae6529c174f7a34f55145434a35070e2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py @@ -0,0 +1,1067 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import dataclasses +import itertools +import logging +import weakref +from collections import defaultdict +from collections.abc import Sequence +from functools import cache +from typing import cast, NamedTuple, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.tensor._dtensor_spec import ( + DTensorSpec, + ShardOrder, + ShardOrderEntry, + TensorMeta, +) +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._debug_mode import get_active_debug_mode + + +logger = logging.getLogger(__name__) + +# Global configuration flag to control the redistribution planning strategy. +# When True, forces the graph-based algorithm using Dijkstra's shortest path. +# When False, prefers the greedy algorithm for faster planning. Uses the graph-based algorithm +# only when necessary to support strided-shard redistribution +_FORCE_MIN_COST_REDISTRIBUTION_PLAN: Optional[bool] = None + + +@contextlib.contextmanager +def use_min_cost_redistribution_plan(enabled: bool = True): + """ + Context manager to control the redistribution planning strategy for DTensor operations. + + This context manager allows you to choose between two algorithms for computing the + sequence of collective operations needed to redistribute a DTensor from one placement + to another: + + - **Graph-based**: Uses Dijkstra's algorithm to find the minimum-cost path + through all possible placement transformations. This approach considers the global + cost of all collective operations and finds the optimal sequence. Best for complex + redistribution patterns where reducing communication cost and memory overhead is critical. + + - **Greedy**: Uses a heuristic approach that makes locally optimal choices + at each step. This is faster to compute but may not produce the globally optimal + transformation sequence. Best for simple redistribution patterns or when planning + speed is more important than optimal communication. + + **Default Behavior (without this context manager):** + + When this context manager is NOT used, the algorithm selection follows this priority: + + 1. **Non-default shard orders** + → Always use graph-based algorithm (required for correctness) + + 2. **Explicit `use_graph_based_transform` parameter** to `_gen_transform_infos_non_cached` + → Use the specified algorithm (True = graph-based, False = greedy) + + 3. **No explicit parameter** (default case) + → Use greedy algorithm for faster planning + + **Behavior with this context manager:** + + This context manager overrides the default selection by setting the global flag + `_FORCE_MIN_COST_REDISTRIBUTION_PLAN`, which takes precedence over the explicit + `use_graph_based_transform` parameter (but not over non-default shard order requirements). + + **Cache Considerations:** + + The redistribution planner caches transform info for performance via the `@cache` + decorator on `_gen_transform_infos`. If you need to change the algorithm selection + for the same input specs, clear the cache using `_gen_transform_infos.cache_clear()` + to ensure the new setting takes effect and doesn't reuse cached results from a + previous run. + + Args: + enabled (bool): If True, forces the use of the graph-based algorithm. + If False, forces the use of the greedy algorithm. + Default: True + """ + global _FORCE_MIN_COST_REDISTRIBUTION_PLAN + old_value = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = enabled + try: + yield + finally: + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = old_value + + +class _TransformInfo(NamedTuple): + mesh_dim: int + src_dst_placements: tuple[Placement, Placement] + # logical_shape on this mesh dimension + logical_shape: list[int] + + +# Global cache for DTensorRedistributePlanner instances +_planner_cache: dict[ + tuple[weakref.ReferenceType, int], "DTensorRedistributePlanner" +] = {} + + +def get_redistribute_planner( + device_mesh: DeviceMesh, tensor_dimension: int +) -> "DTensorRedistributePlanner": + """ + Factory function to get or create a DTensorRedistributePlanner instance. + This function provides transparent caching of planner instances based on + device_mesh and tensor_dimension. Multiple calls with the same parameters + will return the same cached instance for better performance. + Args: + device_mesh: The device mesh for the planner + tensor_dimension: Number of tensor dimensions + Returns: + A DTensorRedistributePlanner instance (potentially cached) + """ + cache_key = (weakref.ref(device_mesh), tensor_dimension) + + if cache_key not in _planner_cache: + planner = DTensorRedistributePlanner(device_mesh, tensor_dimension) + _planner_cache[cache_key] = planner + + return _planner_cache[cache_key] + + +def clear_redistribute_planner_cache() -> None: + """Clear the cache of DTensorRedistributePlanner instances.""" + _planner_cache.clear() + + +class DTensorRedistributePlanner: + """ + This class is used to plan the collective calls to transform the local shard + of the DTensor from its current spec to the target spec. + Suppose there are N tensor dimensions and M mesh dimensions, the total + possible state size will be (N+2)*M*M!. + Note: Use get_redistribute_planner() factory function instead of direct + instantiation for automatic caching. + """ + + @dataclasses.dataclass(frozen=True, slots=True) + class DistState: + placements: tuple[Placement, ...] + tensor_dim_to_mesh_dim: ShardOrder + _hash: int | None = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + + def __str__(self): + return DTensorSpec.format_shard_order_str( + self.placements, + self.tensor_dim_to_mesh_dim, + ) + + def __repr__(self): + return self.__str__() + + def __post_init__(self): + # precompute hash after all attributes are set + object.__setattr__( + self, + "_hash", + self._compute_hash(), + ) + + def __hash__(self) -> int: + return self._hash if self._hash is not None else self._compute_hash() + + def _compute_hash(self) -> int: + return hash( + ( + self.placements, + self.tensor_dim_to_mesh_dim, + ) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DTensorRedistributePlanner.DistState): + return False + if self._hash != other._hash: + return False + return ( + self.placements, + self.tensor_dim_to_mesh_dim, + ) == ( + other.placements, + other.tensor_dim_to_mesh_dim, + ) + + def _to_tuple(self, x): + """Convert a nested list structure to a nested tuple structure.""" + if isinstance(x, list | tuple): + return tuple(self._to_tuple(item) for item in x) + return x + + @staticmethod + def _dict_to_ShardOrder(x: dict[int, list[int]]) -> ShardOrder: + """Convert dict to ShardOrder""" + return tuple( + ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value)) + for key, value in sorted(x.items()) + if value + ) + + @staticmethod + def _ShardOrder_to_dict(x: ShardOrder) -> dict[int, list[int]]: + """Convert ShardOrder to dict with tensor dim as key""" + tensor_mesh_dim_dict = defaultdict(list) + for entry in x: + tensor_mesh_dim_dict[entry.tensor_dim] = list(entry.mesh_dims) + return tensor_mesh_dim_dict + + @staticmethod + def stringify_transform_infos( + mesh: DeviceMesh, + transform_infos: Sequence[_TransformInfo], + src_placement: tuple[Placement, ...], + src_shard_order: ShardOrder | None = None, + ) -> str: + """ + Generate a string representation of the sequence of state transitions + (placements and shard orders) as described by the given transform_info. + + Args: + mesh: The DeviceMesh used for the redistribution. + transform_infos: A sequence of _TransformInfo objects describing each + transformation step. + src_placement: The initial tuple of Placement objects. + src_shard_order: (Optional) The initial ShardOrder representing + the mapping of tensor dimensions to mesh dimensions. If None, + the default shard order is computed from src_placement and mesh. + + Returns: + A string showing the sequence of DistState transitions, separated by '->'. + """ + assert len(src_placement) == mesh.ndim + if src_shard_order is None: + src_shard_order = DTensorSpec.compute_default_shard_order(src_placement) + cur_placement = list(src_placement) + shard_order_dict = DTensorRedistributePlanner._ShardOrder_to_dict( + src_shard_order + ) + cur_state = DTensorRedistributePlanner.DistState( + tuple(cur_placement), src_shard_order + ) + state_list = [ + cur_state, + ] + for transform_info in transform_infos: + src_dim_placement, dst_dim_placement = transform_info.src_dst_placements + if src_dim_placement.is_shard(): + src_dim = src_dim_placement.dim # type: ignore[attr-defined] + assert ( + src_dim in shard_order_dict and len(shard_order_dict[src_dim]) > 0 + ) + shard_order_dict[src_dim].pop() + if dst_dim_placement.is_shard(): + dst_dim = dst_dim_placement.dim # type: ignore[attr-defined] + if dst_dim not in shard_order_dict: + shard_order_dict[dst_dim] = [] + shard_order_dict[dst_dim].append(transform_info.mesh_dim) + cur_placement[transform_info.mesh_dim] = dst_dim_placement + new_state = DTensorRedistributePlanner.DistState( + tuple(cur_placement), + DTensorRedistributePlanner._dict_to_ShardOrder(shard_order_dict), + ) + state_list.append(new_state) + return "->".join([str(s) for s in state_list]) + + def __init__( + self, + device_mesh: DeviceMesh, + tensor_dimension: int, + ) -> None: + """ + Initialize DTensorRedistributePlanner. + + Args: + device_mesh: The device mesh for this planner + tensor_dimension: Number of tensor dimensions + """ + self.device_mesh = device_mesh + self.coordinate = device_mesh.get_coordinate() + assert self.coordinate is not None + self.tensor_dimension = tensor_dimension + self.setup_collective_cost() + + def setup_collective_cost( + self, + all_reduce_cost: int = 4, + all_to_all_cost: int = 1, + all_gather_cost: int = 2, + reduce_scatter_cost: int = 2, + chunk_cost: int = 0, + ) -> None: + """ + Set up the cost weights for different collective operations. + """ + # those can be turned in a handler considering the tensor dim size + self.all_reduce_cost = all_reduce_cost + self.all_to_all_cost = all_to_all_cost + self.all_gather_cost = all_gather_cost + self.reduce_scatter = reduce_scatter_cost + self.chunk_cost = chunk_cost + + def get_next_state( + self, + placements: tuple[Placement, ...], + tensor_mesh_dim_tuple: ShardOrder, + ) -> dict["DTensorRedistributePlanner.DistState", int]: + # We map tensor dimensions to device mesh axes, similar to JAX-style + # sharding representation. Notation: + # S()[] means tensor dimension + # is sharded on the listed device mesh axes, where + # is sorted by device order. + # + # To generalize to arbitrary dimensionality, we use the following notation: + # S(a)[x, ...] : tensor dimension 'a' is sharded on device mesh axes x, ... (variadic, possibly empty) + # R[...] : replicated on the listed device mesh axes (possibly empty) + # P[...] : partial on the listed device mesh axes (possibly empty) + # The ellipsis '...' denotes a variadic wildcard, i.e., zero or more device mesh axes. + # + # Below are possible transitions from one sharding state to another. + # We use `S` for Shard, `R` for Replicate, and `P` for Partial. + # + # Case 1. Shard(a) -> Shard(b), use all-to-all (a2a), applies to: + # S(a)[..., x] -> S(b)[..., x] + # or + # S(a)[..., x, y]S(b)[..., z, k] -> S(a)[..., x]S(b)[..., z, k, y] + # where device order of 'y' > device order of 'z' and 'k' + # + # Case 2. Shard() -> Replicate(), use all-gather, applies to: + # S(a)[..., x, y, z] -> S(a)[..., x, y] + # + # Case 3. Partial() -> Replicate(), use all-reduce, applies to: + # P[..., x, y] -> P[..., y] or P[..., x] + # Note: this case can be disabled because all-reduce technically is not + # a primitive since it combines a reduce-scatter + all-gather. + # + # Case 4. Replicate() -> Shard(), use chunk, applies to: + # S(a)[..., z] -> S(a)[..., z, y] (`a` can be any tensor dim). Note that + # 'y' must be after 'z'. + # + # Case 5. Partial() -> Shard(), use reduce-scatter, applies to: + # P[..., x, y] -> P[..., x]S(a)[..., y] or P[..., x, y] -> P[..., y]S(a)[..., x] + # + # Case 6. Replicate() -> Partial(), local math op, applies to: + # R* -> P[..., x] + # + # NB: Device order in Partial placement doesn't take impact. We should be able + # to operate on any Partial mesh dim. + + # list of [DistState, cost] + all_next_state: dict[DTensorRedistributePlanner.DistState, int] = {} + + tensor_mesh_dim_dict = DTensorRedistributePlanner._ShardOrder_to_dict( + tensor_mesh_dim_tuple + ) + ###################################################################### + # handle case 1: Shard(a) -> Shard(b) + # For S(a), S(b), only the last device order of S(a) and S(b) can be a2a + # interchangeably. + + # convert sparse tuple + for entry in tensor_mesh_dim_tuple: + src_tensor_dim = entry.tensor_dim + for dst_tensor_dim in range(self.tensor_dimension): + if src_tensor_dim == dst_tensor_dim: + continue + # try move the last sharded device dim from + # Shard(src_tensor_dim) to Shard(dst_tensor_dim) + move_mesh_dim = tensor_mesh_dim_dict[src_tensor_dim].pop() + tensor_mesh_dim_dict[dst_tensor_dim].append(move_mesh_dim) + new_placements = list(placements) + new_placements[move_mesh_dim] = Shard(dst_tensor_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.all_to_all_cost + # reset content for next iteration + tensor_mesh_dim_dict[src_tensor_dim].append(move_mesh_dim) + tensor_mesh_dim_dict[dst_tensor_dim].pop() + # TODO(zpcore): support discovering submesh to prevent padding when + # tensor dim is not divisible by the mesh dim. + + ###################################################################### + # handle case 2: Shard() -> Replicate() + for entry in tensor_mesh_dim_tuple: + src_tensor_dim = entry.tensor_dim + move_mesh_dim = tensor_mesh_dim_dict[src_tensor_dim].pop() + new_placements = list(placements) + new_placements[move_mesh_dim] = Replicate() + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder(tensor_mesh_dim_dict), + ) + tensor_mesh_dim_dict[src_tensor_dim].append(move_mesh_dim) + all_next_state[dist_state] = self.all_gather_cost + + ###################################################################### + # handle case 3: Partial() -> Replicate() + for src_mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Partial): + continue + new_placements = list(placements) + new_placements[src_mesh_dim] = Replicate() + dist_state = self.DistState( + self._to_tuple(new_placements), tensor_mesh_dim_tuple + ) + all_next_state[dist_state] = self.all_reduce_cost + + ###################################################################### + # handle case 4: Replicate() -> Shard() + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Replicate): + continue + for dst_tensor_dim in range(self.tensor_dimension): + # try convert placement[mesh_dim] to Shard(dst_tensor_dim) + new_placements = list(placements) + new_placements[mesh_dim] = Shard(dst_tensor_dim) + tensor_mesh_dim_dict[dst_tensor_dim].append(mesh_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.chunk_cost + tensor_mesh_dim_dict[dst_tensor_dim].pop() + + ###################################################################### + # handle case 5: Partial() -> Shard() + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Partial): + continue + for dst_tensor_dim in range(self.tensor_dimension): + # try convert placement[mesh_dim] to Shard(dst_tensor_dim) + new_placements = list(placements) + new_placements[mesh_dim] = Shard(dst_tensor_dim) + tensor_mesh_dim_dict[dst_tensor_dim].append(mesh_dim) + dist_state = self.DistState( + self._to_tuple(new_placements), + DTensorRedistributePlanner._dict_to_ShardOrder( + tensor_mesh_dim_dict + ), + ) + all_next_state[dist_state] = self.reduce_scatter + tensor_mesh_dim_dict[dst_tensor_dim].pop() + + ###################################################################### + # handle case 6: Replicate() -> Partial(), default to partial(sum) + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Replicate): + continue + new_placements = list(placements) + new_placements[mesh_dim] = Partial() + dist_state = self.DistState( + self._to_tuple(new_placements), tensor_mesh_dim_tuple + ) + all_next_state[dist_state] = self.chunk_cost + + return all_next_state + + # TODO(zpcore): if the dst_state contains special placement like + # `_MaskPartial`, we will never reach that state. Need to support this case. + def find_min_cost_path( + self, src_state: DistState, dst_state: DistState + ) -> list["DTensorRedistributePlanner.DistState"]: + """ + Find the min cost path from src_state to dst_state using Dijkstra's + algorithm. + + Args: + src_state: The source state + dst_state: The destination state + + Returns: + A list of states representing the min cost path from src_state to + dst_state + """ + import heapq + + # priority queue (cost, counter, state, path) for Dijkstra's algorithm + # use counter to break ties and avoid comparing DistState objects + counter = 0 + pq: list[ + tuple[ + int, + int, + DTensorRedistributePlanner.DistState, + list[DTensorRedistributePlanner.DistState], + ] + ] = [(0, counter, src_state, [src_state])] + visited = set() + while pq: + cost, _, current_state, path = heapq.heappop(pq) + if current_state == dst_state: + return path + if current_state in visited: + continue + visited.add(current_state) + # get all possible next states and their costs + next_states = self.get_next_state( + current_state.placements, current_state.tensor_dim_to_mesh_dim + ) + for next_state, transition_cost in next_states.items(): + if next_state not in visited: + new_cost = cost + transition_cost + new_path = path + [next_state] + counter += 1 + heapq.heappush(pq, (new_cost, counter, next_state, new_path)) + raise AssertionError( + f"No path found from src_state {src_state} to dst_state {dst_state}" + ) + + def get_logical_shape( + self, + src_state: "DTensorRedistributePlanner.DistState", + mesh_dim: int, + full_tensor_shape: tuple[int, ...], + ) -> list[int]: + new_logical_shape = list(full_tensor_shape) + assert self.coordinate is not None + for entry in src_state.tensor_dim_to_mesh_dim: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + assert len(mesh_dims) > 0 + for mdim in mesh_dims: + if mdim == mesh_dim: + continue + new_size = Shard.local_shard_size_and_offset( + new_logical_shape[tensor_dim], + self.device_mesh.size(mesh_dim=mdim), + self.coordinate[mdim], + )[0] + new_logical_shape[tensor_dim] = new_size + return new_logical_shape + + def generate_graph_based_transform_infos( + self, + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + full_tensor_shape: tuple[int, ...], + ) -> list[_TransformInfo]: + # In case _StridedShard exists in placements, we let _StridedShard have + # higher priority to express shard_order. + if any( + isinstance(placement, _StridedShard) for placement in src_spec.placements + ): + src_placements, src_shard_order = ( + DTensorSpec._normalize_placements_into_shard_order( + src_spec.placements, src_spec.mesh + ) + ) + else: + src_placements = src_spec.placements + src_shard_order = src_spec.shard_order + if any( + isinstance(placement, _StridedShard) for placement in dst_spec.placements + ): + dst_placements, dst_shard_order = ( + DTensorSpec._normalize_placements_into_shard_order( + dst_spec.placements, dst_spec.mesh + ) + ) + else: + dst_placements = dst_spec.placements + dst_shard_order = dst_spec.shard_order + if src_shard_order is None or dst_shard_order is None: + raise NotImplementedError( + "Redistribution of _StridedShard placement is only supported for " + "_StridedShard that can be converted to ordered Shard placements. " + "Full _StridedShard redistribution support is not yet implemented." + ) + src_state = self.DistState(src_placements, src_shard_order) + dst_state = self.DistState(dst_placements, dst_shard_order) + transform_infos: list[_TransformInfo] = [] + state_path = self.find_min_cost_path(src_state, dst_state) + for cur_state, nxt_state in itertools.pairwise(state_path): + # find the mesh_dim that is different between cur_state and nxt_state + if cur_state.placements != nxt_state.placements: + update_mesh_dim = -1 + for mesh_dim, (cur_placement, nxt_placement) in enumerate( + zip(cur_state.placements, nxt_state.placements) + ): + if cur_placement != nxt_placement: + if update_mesh_dim != -1: + raise AssertionError( + "Multiple mesh_dims are different between cur_state and nxt_state" + ) + update_mesh_dim = mesh_dim + logical_shape = self.get_logical_shape( + cur_state, mesh_dim, full_tensor_shape + ) + transform_infos.append( + _TransformInfo( + mesh_dim=update_mesh_dim, + src_dst_placements=(cur_placement, nxt_placement), + logical_shape=logical_shape, + ) + ) + + return transform_infos + + def generate_greedy_transform_infos( + self, + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + ) -> list[_TransformInfo]: + """ + Generate the transform infos from the source placements to the target placements. + + To transform from source to target placement it might have multiple steps, i.e. it + might decompose Si -> Sj into Si -> R -> Sj. + This would detect if there're mis-aligned/nested shardings between src/dst placements. + E.g. Suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)), + in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because in + the former is a nested-sharding of a tensor already already sharded dimension 0, whereas + the latter is the first sharding on tensor dimension 0. + """ + # logical shape records the logic tensor shape on the mesh dimension + # this is useful to ensure uneven sharding gets correct output shape + assert self.coordinate is not None + initial_logical_shape = list(src_spec.shape) + mesh_dims_to_logical_shape = [initial_logical_shape] + transform_infos: list[_TransformInfo] = [] + if self.device_mesh.ndim == 1: + # if device_mesh is 1D, redistribute is a simple direct + # transformation + transform_infos.append( + _TransformInfo( + mesh_dim=0, + src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]), + logical_shape=initial_logical_shape, + ) + ) + return transform_infos + + # Handle multi-dim device mesh placement redistribution First, we need + # to build the logical shape for each mesh dim for correct allgather + # uneven shards on each mesh dim (with dynamic padding) + for i, src in enumerate(src_spec.placements): + current_logical_shape = mesh_dims_to_logical_shape[i] + if isinstance(src, Shard): + if i < self.device_mesh.ndim - 1: + # calculate and save the logical shape for this sharding + mesh_dim_size = self.device_mesh.size(mesh_dim=i) + local_shard_size, _ = src._local_shard_size_and_offset( + current_logical_shape[src.dim], + mesh_dim_size, + self.coordinate[i], + ) + new_logical_shape = list(current_logical_shape) + new_logical_shape[src.dim] = local_shard_size + mesh_dims_to_logical_shape.append(new_logical_shape) + else: + mesh_dims_to_logical_shape.append(current_logical_shape) + + # Next, we need to derive the transform infos from src to dst + # placements, here we use a greedy search with step by step state + # transformations + current_placements = list(src_spec.placements) + target_placements = list(dst_spec.placements) + + if src_spec.num_shards > 1: + # If src_spec have sharding, it could potentially have sharding that + # is misaligned with dst_spec a common case of this is nested + # sharding (i.e. (S(0), S(0)) -> (R, S(0))). In those cases, we + # first traverse from inner placement to outer placement to detect + # misaligned shardings and properly replicate nested sharding first. + for mesh_dim in reversed(range(len(current_placements))): + current = current_placements[mesh_dim] + target = target_placements[mesh_dim] + # If target is not Shard, we can directly redistribute since we + # are traversing from inner to outer placements here + if isinstance(target, Shard): + # If target is Shard, check for nested sharding on the + # tensor dim BEFORE the current mesh_dim + shard_dim = target.dim + current_mesh_sharding, target_mesh_sharding = [], [] + for i, (s, p) in enumerate( + zip(current_placements, target_placements) + ): + if i >= mesh_dim: + break + if s.is_shard(shard_dim): + current_mesh_sharding.append(i) + if p.is_shard(shard_dim): + target_mesh_sharding.append(i) + + if current_mesh_sharding != target_mesh_sharding: + # if current/target_placements have misaligned sharding + # on the tensor dim BEFORE the current mesh_dim, we need + # to replicate the tensor on the mesh dim first to clear + # the nested sharding + target = Replicate() + + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + # We always traverse from outer placement to inner placement to collect + # the remaining needed transform infos (i.e. the replication from nested + # sharding might need to further perform resharding to Shard again) + for mesh_dim, (current, target) in enumerate( + zip(current_placements, target_placements) + ): + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + return transform_infos + + +def _gen_transform_infos_non_cached( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + use_graph_based_transform: bool | None = None, +) -> list[_TransformInfo]: + device_mesh = src_spec.device_mesh + src_shard_order = src_spec.shard_order + dst_shard_order = dst_spec.shard_order + # DTensorSpec should automatically generate shard_order, and it can be () if + # no shard. + assert src_shard_order is not None and dst_shard_order is not None + # Determine which transform strategy to use: + # 1. Non-standard device order → always use graph-based + # 2. Global flag or explicit parameter True → use graph-based + # 3. Otherwise → use greedy + has_non_default_order = not all( + DTensorSpec.is_default_device_order(order) + for order in (src_shard_order, dst_shard_order) + ) + + if has_non_default_order is True: + use_graph_based_transform = True + elif _FORCE_MIN_COST_REDISTRIBUTION_PLAN is not None: + use_graph_based_transform = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + elif use_graph_based_transform is None: + use_graph_based_transform = False + drp = get_redistribute_planner(device_mesh, len(src_spec.shape)) + if use_graph_based_transform: + transform_infos = drp.generate_graph_based_transform_infos( + src_spec, dst_spec, src_spec.shape + ) + else: + transform_infos = drp.generate_greedy_transform_infos(src_spec, dst_spec) + return transform_infos + + +@cache +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, + use_graph_based_transform: bool | None = None, +) -> list[_TransformInfo]: + return _gen_transform_infos_non_cached( + src_spec, dst_spec, use_graph_based_transform + ) + + +def redistribute_local_tensor( + local_tensor: torch.Tensor, + current_spec: DTensorSpec, + target_spec: DTensorSpec, + *, + async_op: bool = False, + is_backward: bool = False, + use_graph_based_transform: bool | None = None, +) -> torch.Tensor: + """ + This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to + the target DTensorSpec, which involves the necessary collective calls to transform + the local shard of the DTensor from its current spec to the target spec. + """ + + if current_spec.mesh != target_spec.mesh: + # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + new_local_tensor = local_tensor + device_mesh = current_spec.mesh + + my_coordinate = device_mesh.get_coordinate() + + if my_coordinate is None: + # if rank is not part of mesh, we skip redistribute and simply return local_tensor, + # which should be an empty tensor + return local_tensor + + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached( + current_spec, target_spec, use_graph_based_transform + ) + else: + transform_infos = _gen_transform_infos( + current_spec, target_spec, use_graph_based_transform + ) + + debug_mode = get_active_debug_mode() + redistribute_context = ( + debug_mode.record_redistribute_calls( # type: ignore[union-attr] + local_tensor, + current_spec.placements, + target_spec.placements, + DTensorRedistributePlanner.stringify_transform_infos( + device_mesh, + transform_infos, + current_spec.placements, + current_spec.shard_order, + ), + ) + if debug_mode is not None + else contextlib.nullcontext() + ) + + with redistribute_context: + for transform_info in transform_infos: + i = transform_info.mesh_dim + current, target = transform_info.src_dst_placements + num_chunks = device_mesh.size(mesh_dim=i) + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + if num_chunks == 1: + # short cut, if there's only one shard, we don't need to do any collective + # comm, just use the original local tensor + new_local_tensor = local_tensor + continue + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_value( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_shard_value( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + new_local_tensor = target_placement._replicate_to_shard( + local_tensor, device_mesh, i, my_coordinate[i] + ) + else: + assert current.is_shard(), ( + f"Current placement should be shard but found {current}" + ) + shard_spec = cast(Shard, current) + if shard_spec.dim != target_placement.dim: + new_local_tensor = shard_spec._to_new_shard_dim( + local_tensor, + device_mesh, + i, + transform_info.logical_shape, + target_placement.dim, + ) + elif target.is_partial(): + if current.is_replicate(): + partial_spec = cast(Partial, target) + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. + new_local_tensor = ( + partial_spec._partition_value(local_tensor, device_mesh, i) + if not is_backward + else local_tensor + ) + elif current.is_shard(): + if not is_backward: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + # for backward shard -> partial, we just need to convert the shard to replicate + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + # partial -> partial no op, should never hit + new_local_tensor = local_tensor + + if not async_op and isinstance( + new_local_tensor, funcol.AsyncCollectiveTensor + ): + new_local_tensor = new_local_tensor.wait() + local_tensor = new_local_tensor + return new_local_tensor + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + async_op: bool = False, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + ): + ctx.async_op = async_op + ctx.backward_dtype = backward_dtype + ctx.original_dtype = input._local_tensor.dtype + + if forward_dtype is not None and forward_dtype != input._local_tensor.dtype: + local_tensor = input._local_tensor.to(dtype=forward_dtype) + current_spec = DTensorSpec( + mesh=device_mesh, + placements=input._spec.placements, + tensor_meta=TensorMeta( + shape=input.shape, + stride=input.stride(), + dtype=forward_dtype, + ), + ) + else: + local_tensor = input._local_tensor + current_spec = input._spec + + ctx.current_spec = current_spec + + if current_spec.placements != placements: + target_spec = DTensorSpec( + device_mesh, placements, tensor_meta=current_spec.tensor_meta + ) + + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. + output = local_tensor + target_spec = current_spec + + # pyrefly: ignore [bad-argument-type] + return dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] + output, + target_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=input.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + async_op = ctx.async_op + backward_dtype = ctx.backward_dtype or ctx.original_dtype + + if backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=backward_dtype, + ), + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + ) + else: + local_tensor = grad_output._local_tensor + current_spec = grad_output._spec + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + + if output.dtype != ctx.original_dtype: + output = output.to(ctx.original_dtype) + + # normalize the target placement to replicate if it is partial + normalized_placements: list[Placement] = [] + for previous_placement in previous_spec.placements: + if previous_placement.is_partial(): + # keep target placement to replicate instead of partial in this case + normalized_placements.append(Replicate()) + else: + normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + ) + # pyrefly: ignore [bad-argument-type] + output_dtensor = dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] + output, + spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=grad_output.requires_grad, + ) + + return ( + output_dtensor, + None, + None, + None, + None, + None, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..c1fddd05c9d6e7f38e637ea10a3bf2ffe0e16fe0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_sharding_prop.py @@ -0,0 +1,680 @@ +# mypy: allow-untyped-defs +import logging +import threading +from collections.abc import Callable, Sequence +from functools import lru_cache +from itertools import chain +from typing import cast + +import torch +from torch._guards import detect_fake_mode +from torch._ops import OpOverload +from torch._subclasses import FakeTensorMode +from torch.distributed._functional_collectives import _are_we_tracing +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OpSpec, + OpStrategy, + OutputSharding, + OutputSpecType, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) +from torch.distributed.tensor.placement_types import _StridedShard, Shard + + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + + +def _length(obj) -> int: + if obj is None: + return 0 + if not isinstance(obj, Sequence): + return 1 + return len(obj) + + +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + def cache_clear(self): + return self.cache.cache_clear() + + +class ShardingPropagator: + def __init__(self) -> None: + self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} + self.op_strategy_funcs: dict[ + OpOverload, + Callable[[OpSchema], StrategyType], + ] = {} + # op map to save static argnum to decide to reuse sharding prop cache or + # re-run sharding prop + self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {} + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) + # op map to save indices of shape (and stride) args which may need to be + # modified in sharding prop + self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = { + # new factory ops + aten.new_empty.default: 1, + aten.new_full.default: 1, + aten.new_ones.default: 1, + aten.new_zeros.default: 1, + aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, + aten.select_backward.default: 1, + aten.slice_backward.default: 1, + } + + def register_sharding_prop_rule( + self, + op_overload: OpOverload, + rule_func: Callable[[OpSchema], OutputSharding], + schema_info: RuntimeSchemaInfo | None = None, + ): + """ + Register a sharding propagation rule for an operator. + """ + self.op_to_rules[op_overload] = rule_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def register_op_strategy( + self, + op_overload: OpOverload, + strategy_func: Callable[[OpSchema], StrategyType], + schema_info: RuntimeSchemaInfo | None = None, + ): + """ + Register a :class:`OpStrategy` generator for an operator. + + During the sharding propagation, DTensor wants to enumerate all + acceptable sharding specs (:class:`OpSpec`) for an operator, + and by "acceptable" we mean that the operator can be executed on + the ``_local_tensor`` of DTensor args/kwargs (with ``OpSpec.input_specs``) + and the output(s) constitute valid DTensor(s) (with ``OpSpec.output_specs``). + + ``strategy_func`` is the function that enumerates such acceptable specs + for the operator ``op_overload``. One general approach to write ``strategy_func`` + is, if the operator has simple arguments structure (e.g. mm, bmm), first enumerating + all sharding specs for the operands, and then filtering out the ones that + are not valid. For example, for ``mm``, the operands are two 2D tensors, and + if both ``input`` and ``mat2`` have sharding placements ``[Shard(0)]``, then this + is not an acceptable ``input_specs``. + + Once we have a way to enumerate all acceptable sharding specs, we can use each + of them to construct a :class:`OpSpec`. The ``OpSpec.input_specs`` directly comes + from the sharding spec, and the ``OpSpec.output_specs`` is therefore determined + (e.g. ``[Shard(1)]`` @ ``[Shard(0)]`` yields ``[Partial()]``). In addition, + :class:`OpSpec` also contains ``redistribute_cost`` which records the redistribution + cost from each :class:`OpSpec` in the source :class:`OpStrategy.strategies` to + the target sharding spec, for each operand. + + The ``strategy_func`` should return a :class:`OpStrategy` which contains a list of + all the :class:`OpSpec`s generated in the above. + + The optional ``schema_info`` tells which non-DTensor args/kwargs could affect the + cache and whether ``pytree`` is needed to flatten the nested args. ``static_argnum`` + marks the starting index of the non-DTensor args that should be hashed into the + sharding propagation hash key, and ``static_kwargkey`` marks the keys of the + non-DTensor kwargs that should be hashed. ``needs_pytree`` should be used when + the input arg has :class:`list` or :class:`dict` structure. + + For example, ``aten.cat.default`` op has a ``List[Tensor]`` argument ``tensors`` + and an ``int`` argument ``dim``. Because ``dim`` affects the sharding propagation + result, we want to pass ``RuntimeSchemaInfo(static_argnum=1)`` because the argument + index of ``dim`` is 1. Besides, we also want to set ``needs_pytree=True`` because + ``tensors`` needs be flattened in sharding propagation. Another example is + ``aten.histc.default``. ``histc`` has 4 arguments (self, bins, min, max) and the + last two would affect sharding propagation along with the :class:`DTensor` argument + ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be + `RuntimeSchemaInfo(static_argnum=2)`. + """ + self.op_strategy_funcs[op_overload] = strategy_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def _propagate_tensor_meta_non_cached( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas + """ + if op_schema.op == aten.equal.default: + # data dependent ops can't be used for fake propagation + return None + + # NOTE: We must call the tracing in fake tensor mode so that it avoids + # materializing memory. + fake_mode = detect_fake_mode() or FakeTensorMode() + with fake_mode: + fake_args = op_schema.gen_fake_args() + fake_kwargs = op_schema.gen_fake_kwargs() + fake_out = op_schema.op(*fake_args, **fake_kwargs) + + if isinstance(fake_out, torch.Tensor): + return TensorMeta( + shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype + ) + + elif isinstance(fake_out, (tuple, list)): + tensor_meta_list: list[TensorMeta | None] = [] + for fake_out_item in fake_out: + if isinstance(fake_out_item, torch.Tensor): + tensor_meta_list.append( + TensorMeta( + shape=fake_out_item.shape, + stride=fake_out_item.stride(), + dtype=fake_out_item.dtype, + ) + ) + else: + tensor_meta_list.append(None) + return ( + tuple(tensor_meta_list) + if isinstance(fake_out, tuple) + else tensor_meta_list + ) + else: + # if fake is not a tensor or tuple of tensor, return as none + return None + + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Cached version of _propagate_tensor_meta_non_cached + This is a private API. Use propagate_tensor_meta instead. + """ + return self._propagate_tensor_meta_non_cached(op_schema) + + def propagate_tensor_meta( + self, op_schema: OpSchema + ) -> None | TensorMeta | Sequence[TensorMeta | None]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas. This is a public API that should be + used if cache should be used. + """ + if _are_we_tracing(): + return self._propagate_tensor_meta_non_cached(op_schema) + else: + return self._propagate_tensor_meta(op_schema) + + def _create_output_spec_with_new_tensor_meta( + self, + op: OpOverload, + output_specs: OutputSpecType, + output_tensor_meta: None | TensorMeta | Sequence[TensorMeta | None], + ) -> OutputSpecType: + """ + Wrap the output_specs with the tensor metadata from the output. + """ + + if isinstance(output_specs, DTensorSpec): + if not isinstance(output_tensor_meta, TensorMeta): + # Either error due to ShardingPropagator or due to incorrect OutputSpec + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + "ShardingPropagator error: output does not have an associated " + "TensorMeta" + ) + raise ValueError( + f"For the op {op.name()}, `output_specs` has 1 output which does " + "not equal the " + f"number of op outputs: {len(output_tensor_meta)}." + ) + return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta) + elif isinstance(output_specs, (tuple, list)): + new_specs: list[DTensorSpec | None] = [] + if not isinstance(output_tensor_meta, (tuple, list)) or len( + output_specs + ) != len(output_tensor_meta): + raise ValueError( + f"For the op {op.name()}, `output_specs` has {len(output_specs)} " + "outputs which does not equal the " + f"number of op outputs {_length(output_tensor_meta)}." + ) + + for i, spec in enumerate(output_specs): + if isinstance(spec, DTensorSpec): + output_tensor_meta_i = output_tensor_meta[i] + if not isinstance(output_tensor_meta_i, TensorMeta): + # NOTE: aten.convolution_backward.default is an exception and it + # needs extra handling because any Tensor in the output tuple + # can be `None` depending on the output_mask parameter. This can + # occur during double backpropagation or when certain gradients + # are not needed (e.g., grad_input when input has requires_grad=False, + # grad_weight/grad_bias when weight/bias have requires_grad=False, + # or grad_bias when bias is None). We explicitly allow the + # corresponding TensorMeta to be `None`. + if ( + op == aten.convolution_backward.default + and i in (0, 1, 2) + and output_tensor_meta_i is None + ): + assert isinstance(output_specs, list) + new_specs.append(None) + continue + else: + raise ValueError( + f"ShardingPropagator error: output {i} of {op.name()} " + "does not have an associated TensorMeta" + ) + + new_specs.append( + spec.shallow_copy_with_tensor_meta(output_tensor_meta_i) + ) + else: + new_specs.append(spec) + + return tuple(new_specs) + else: + assert output_specs is None + return output_specs + + def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: + """ + wrap a op_schema that contains DTensorSpec to another op_schema that contains + OpStrategy/TupleStrategy, the returned op_schema is then used for sharding + strategy propagation on pytorch operators. + """ + + def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([OpSpec(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] + + kwargs_op_strategy = { + k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() + } + + return OpSchema( + op=op_schema.op, + args_schema=tuple(args_op_strategy), + kwargs_schema=kwargs_op_strategy, + schema_info=op_schema.schema_info, + ) + + def propagate(self, op_info: OpInfo) -> None: + # NB: The logic here is duplicated in _propagate_op_sharding_dispatch_slow_path. + # Ideally, this function would be deleted, but there are a handful of + # one off call sites here that aren't cleaned up. + + # We cannot use an lru cache if we know that inputs will have dynamic shapes, + # because SymInts are not hashable. + # This is generally ok because this only happens during tracing in torch.compile, + # and tracing does not need to be as fast as eagermode DTensor usages. + if _are_we_tracing(): + output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) + else: + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) + op_info.output_sharding = output_sharding + + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: + """ + Propagate the sharding for an operator given the op_schema. + """ + # no-op in OSS, logs API usage metrics in meta-internal runs + torch._C._log_api_usage_once( + "torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached" + ) + # special case op, we don't need to propagate for local + # scalar. TODO: figure out a better way to handle this + if op_schema.op is aten._local_scalar_dense.default: + return OutputSharding(None, op_schema) + + out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) + if op_schema.op in self.op_strategy_funcs: + # wrap the op_schema with op strategy for sharding strategy propagation + strategy_schema = self._wrap_with_op_strategy(op_schema) + + # run sharding strategy propagation/generation + op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema) + + if isinstance(op_strategy, OpStrategy): + # single Op strategy + output_strategy = self._select_strategy(op_strategy, op_schema) + + # check if we need to redistribute the input + needs_redistribute = False + # check if we want to use args value from redistribute_schema + use_val_from_redistribute_schema = False + expected_input_specs: list[DTensorSpec] = [] + + # in case where the op does not specify input_specs and output_specs + # is a DTensorSpec, we use output_specs as the spec for each DTensor + # input arg. + if output_strategy.input_specs is None: + assert isinstance(output_strategy.output_specs, DTensorSpec) + + for idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + output_strategy.output_spec + if output_strategy.input_specs is None + else output_strategy.input_specs[idx] + ) + expected_input_specs.append( + desired_spec.shallow_copy_with_tensor_meta( + input_spec.tensor_meta + ) + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(expected_input_specs), {} + ) + suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) + + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: + assert isinstance(output_strategy.output_spec, DTensorSpec) + # It happens when the output has the same shape as the input + # and the input placements are not all Replicate(). + if any( + isinstance(p, Shard | _StridedShard) + for p in output_strategy.output_spec.placements + ): + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec + ) + needs_redistribute = True + use_val_from_redistribute_schema = True + + # construct output spec for the op + if op_schema.return_type_tuple_tensor_like(): + # for ops that return multiple tensors and the output_specs is not + # a tuple, we use a tuple of that single output spec as the new + # output_specs + output_specs: OutputSpecType = output_strategy.output_specs + if isinstance(output_specs, DTensorSpec): + output_specs = tuple( + # create a new DTensorSpec with the same placement as the + # output_specs in output_strategy + DTensorSpec( + mesh=output_specs.mesh, + placements=output_specs.placements, + tensor_meta=output_specs.tensor_meta, + ) + for _ in range(len(op_schema.op._schema.returns)) + ) + elif ( + op_schema.return_type_tensor() + or op_schema.return_type_list_tensor_like() + ): + output_specs = output_strategy.output_specs + else: + output_specs = None + + output_sharding = OutputSharding( + output_specs, + suggestion_schema, + needs_redistribute=needs_redistribute, + use_val_from_redistribute_schema=use_val_from_redistribute_schema, + ) + elif isinstance(op_strategy, TupleStrategy): + # tuple strategy output sharding processing + # runtime select OpSpec for each TupleStrategy input arg + selected_strategies: list[OpSpec] = [] + out_spec_list: list[DTensorSpec] = [] + for strategy in op_strategy.children: + assert isinstance(strategy, OpStrategy) + selected_strategy = self._select_strategy(strategy) + selected_strategies.append(selected_strategy) + out_spec_list.append(selected_strategy.output_spec) + + needs_redistribute = False + suggestion_args: list[object] = [] + tensor_or_list_tensor_arg_idx = 0 + + for arg in op_schema.args_schema: + if ( + arg + and isinstance(arg, (list, tuple)) + and isinstance(arg[0], DTensorSpec) + ): + expected_input_spec_list: list[DTensorSpec] = [] + for idx, arg_spec in enumerate(arg): + expected_input_spec = selected_strategies[idx].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) + ) + if arg_spec.placements != expected_input_spec.placements: + needs_redistribute = True + expected_input_spec_list.append(expected_input_spec) + suggestion_args.append( + tuple(expected_input_spec_list) + if isinstance(arg, tuple) + else expected_input_spec_list + ) + tensor_or_list_tensor_arg_idx += 1 + + elif isinstance(arg, DTensorSpec): + expected_input_spec = selected_strategies[0].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) + ) + if arg.placements != expected_input_spec.placements: + needs_redistribute = True + suggestion_args.append(expected_input_spec) + tensor_or_list_tensor_arg_idx += 1 + else: + suggestion_args.append(arg) + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema + ) + + output_sharding = OutputSharding( + tuple(out_spec_list) if out_tensor_meta is not None else None, + suggestion_schema, + needs_redistribute=needs_redistribute, + use_val_from_redistribute_schema=False, + ) + else: + raise ValueError("Unsupported op strategy type") + + # associate the output sharding with the output tensor metadata + new_output_spec = self._create_output_spec_with_new_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + output_sharding.output_spec = new_output_spec + return output_sharding + elif op_schema.op in self.op_to_rules: + # propagate the sharding with rule + sharding_prop_func = self.op_to_rules[op_schema.op] + + # step 1. there's sharding propagation rule, run + # sharding propagation to get the output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except NotImplementedError as e: + raise e + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}.\nError: {e}" + ) from e + + # step 2. if can't get output_spec from sharding + # propagation (i.e. no rules apply for input + # placements), we return the output sharding + # with schema suggestions, which can be used to + # decide how to do redistribute on inputs + if output_sharding.output_spec is None: + if output_sharding.redistribute_schema is None: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}!" + ) + else: + # we do auto redistribute on inputs if necessary + # run sharding propagation again with suggested schema + propagation_res = sharding_prop_func( + output_sharding.redistribute_schema + ) + # we set the output sharding with the new propagation result + # so that dispatching know both output_spec and redistribute_schema + # exist, which indicates a reshard is needed + output_sharding.output_spec = propagation_res.output_spec + output_sharding.needs_redistribute = True + + # associate the output sharding with the output tensor metadata + new_output_spec = self._create_output_spec_with_new_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + output_sharding.output_spec = new_output_spec + + return output_sharding + else: + raise NotImplementedError( + f"Operator {op_schema.op} does not have a sharding strategy registered." + ) + + def _select_strategy( + self, strategy: OpStrategy, op_schema: OpSchema | None = None + ) -> OpSpec: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if len(strategy.strategies) == 1: + # short cut with only one possible OpSpec + return strategy.strategies[0] + + op_spec_costs: list[torch.types.FloatLikeType] = [] + no_redistribute_strategy_index: int = -1 + negative_cost_index: int = -1 + zero_cost_index: int = -1 + for strategy_idx, op_spec in enumerate(strategy.strategies): + assert op_spec.redistribute_cost is not None, ( + "must set redistribute cost each OpSpec!" + ) + redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) + op_spec_costs.append(redistribute_cost) + + # If there are strategies with negative/zero/no redistribute cost, + # we record those indices. + # TODO: Currently this only applies to OpStrategy selection. Requires extra + # logic to make it work for TupleStrategy, if needed. + if op_schema is not None: + if guard_or_false(redistribute_cost < 0): + if ( + negative_cost_index == -1 + or redistribute_cost < op_spec_costs[negative_cost_index] + ): + negative_cost_index = strategy_idx + elif guard_or_false(redistribute_cost == 0): + needs_redistribute = False + for spec_idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + op_spec.output_spec + if op_spec.input_specs is None + else op_spec.input_specs[spec_idx] + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + break + + if not needs_redistribute: + no_redistribute_strategy_index = strategy_idx + elif zero_cost_index == -1: + zero_cost_index = strategy_idx + + # prioritize negative/zero/no redistribute cost strategies + if negative_cost_index != -1: + # If there's negative cost, we select the one with the minimal cost, + # even if this means we need to redistribute, e.g. via local chunking. + # E.g. this can happen for ops in self.op_to_shape_and_stride_idx + # when the inputs / outputs are sharded. + selected_strategy_index = negative_cost_index + elif no_redistribute_strategy_index != -1: + selected_strategy_index = no_redistribute_strategy_index + elif zero_cost_index != -1: + selected_strategy_index = zero_cost_index + else: + # default to choosing minimal redistribute cost + min_cost = min(op_spec_costs) + selected_strategy_index = op_spec_costs.index(min_cost) + + return strategy.strategies[selected_strategy_index] + + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + ) -> OpSchema: + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx + else: + shape_idx = shape_stride_idx + stride_idx = None + + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( + out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True + ) + + # adjust the stride arg for aten.new_empty_strided.default + if stride_idx: + expected_input_schema[stride_idx] = compute_local_stride( + out_tensor_meta.stride, spec.mesh, spec.placements + ) + + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1673dd7e34b994470386e1fb1a5079c302302393 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_shards_wrapper.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + MetadataIndex, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + TensorWriteData, + WriteItem, + WriteItemType, +) + + +aten = torch.ops.aten + + +class LocalShardsWrapper(torch.Tensor): + """ + A wrapper class to hold local shards of a DTensor. + This class is used largely for checkpointing purposes and implicitly subtypes + the _Checkpointable protocol. + """ + + __slots__ = ["_local_shards", "_storage_meta"] + _local_shards: list[torch.Tensor] + _storage_meta: TensorStorageMetadata + + @staticmethod + def __new__( + cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]] + ) -> "LocalShardsWrapper": + assert all( + tensor.device == local_shards[0].device for tensor in local_shards[1:] + ) + + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + + # we calculate the total tensor size by "concat" on second tensor dimension + cat_tensor_shape = list(local_shards[0].size()) + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[1] += shard.size()[1] + + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) + wrapper_shape = torch.Size(cat_tensor_shape) + chunks_meta = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for shard, offset in zip(local_shards, local_offsets) + ] + + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size(cat_tensor_shape), + ) + r._local_shards = local_shards + r._storage_meta = TensorStorageMetadata( + properties=wrapper_properties, + size=wrapper_shape, + chunks=chunks_meta, + ) + + return r + + # necessary for ops dispatching from this subclass to its local shards + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + kwargs = kwargs or {} + + dispatcher = { + torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor, + torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor, + aten._to_copy.default: cls.handle_to_copy, + aten.view.default: cls.handle_view, + aten.equal.default: cls.handle_equal, + aten.detach.default: cls.handle_detach, + aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, + } + + if func in dispatcher: + return dispatcher[func](args, kwargs) + else: + raise NotImplementedError( + f"{func} is not supported for LocalShardsWrapper!" + ) + + @staticmethod + def handle_all_gather_into_tensor(args, kwargs) -> torch.Tensor: + dim = args[0].local_sizes()[0][1] + cat_tensor = torch.cat( + [t.view(-1) for t in args[0].local_shards()], dim=0 + ).view(-1, dim) + return torch.ops._c10d_functional.all_gather_into_tensor.default( + cat_tensor, *args[1:], **kwargs + ) + + @staticmethod + def handle_wait_tensor(args, kwargs) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(args[0]) + + @staticmethod + def handle_to_copy(args, kwargs) -> torch.Tensor: + res_shards_list = [ + aten._to_copy.default(shard, *args[1:], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_view(args, kwargs) -> "LocalShardsWrapper": + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardless of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_equal(args, kwargs) -> bool: + """ + LocalShardsWrapper equal impl also checks for equality of storage metadata + and the order of shards + """ + a, b = args[0], args[1] + if len(a.local_shards()) != len(b.local_shards()): + return False + if not all( + aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards()) + ): + return False + if a.storage_metadata() != b.storage_metadata(): + return False + return True + + @staticmethod + def handle_detach(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + deatched_local_shards = [ + aten.detach.default(shard) for shard in self_ls.local_shards() + ] + self_ls._local_shards = deatched_local_shards + self_ls._storage_meta.properties.requires_grad = False + return self_ls + + @staticmethod + def handle_clone(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + desired_memory_format = kwargs.get("memory_format", None) + if desired_memory_format and desired_memory_format != torch.preserve_format: + raise NotImplementedError( + f"{desired_memory_format} is not supported for LocalShardsWrapper!" + ) + cloned_local_shards = [ + shard.clone(memory_format=desired_memory_format) + for shard in self_ls._local_shards + ] + return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + + @staticmethod + def handle_new_empty(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + + @property + def device(self) -> torch._C.device: # type: ignore[override] + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) + + @property + def is_meta(self) -> bool: # type: ignore[override] + return self._local_shards[0].is_meta if self._local_shards else True + + def is_pinned(self) -> bool: # type: ignore[override] + return self._storage_meta.properties.pin_memory + + def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": + self._storage_meta.properties.requires_grad = requires_grad + [shard.requires_grad_(requires_grad) for shard in self._local_shards] + return self + + def local_shards(self) -> list[torch.Tensor]: + """ + Returns a list of :class:`torch.Tensor' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + def local_sizes(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local sizes for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.sizes for chunk in self._storage_meta.chunks] + + def local_offsets(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local offsets for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.offsets for chunk in self._storage_meta.chunks] + + @property + def local_chunks(self) -> list[ChunkStorageMetadata]: + """ + Returns a :class:`list[ChunkStorageMetadata]` object corresponding to the + metadata for each tensor shard + """ + return self._storage_meta.chunks + + def storage_metadata(self) -> TensorStorageMetadata: + """ + Returns a :class:`TensorStorageMetadata` object corresponding to the + metadata for the local tensor on current rank + """ + return self._storage_meta + + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__(self, fqn: str, object: Any) -> list[WriteItem]: + """ + For compatibility with DCP, we support creation of WriteItems + such that they can be saved properly. + """ + return [ + WriteItem( + index=MetadataIndex(fqn, chunks.offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=chunks.offsets, + sizes=chunks.sizes, + ), + properties=self._storage_meta.properties, + size=object.size(), + ), + ) + for tensor, chunks in zip(self.local_shards(), self.local_chunks) + ] + + def __create_chunk_list__(self) -> list[ChunkStorageMetadata]: + """ + For compatibility with DCP, we support creation of chunk lists + such that they can be saved properly. + """ + return self._storage_meta.chunks + + def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: + """ + For compatibility with DCP, we support finding shard based on index + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + # Fast lookup path + if index.index is not None: + if ( + len(self._local_shards) > index.index + and self._storage_meta.chunks[index.index].offsets == index.offset + ): + return self._local_shards[index.index] + + if index.offset is not None: + for shard, chunk in zip(self._local_shards, self._storage_meta.chunks): + if chunk.offsets == index.offset: + return shard + + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + + def _get_tensor_size_bytes(self) -> int: + object_size = 0 + for shard in self.local_shards(): + object_size += shard.nelement() * shard.element_size() + return object_size + + def __hash__(self) -> int: + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" + + def __str__(self) -> str: + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..275cb07934b5030bc9cd5bc71dc66f82e98eb3b5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_tp_conv.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from typing import cast + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor + + +aten = torch.ops.aten + + +def _requires_data_exchange(padding, dim_map) -> bool: + # Data exchange is not need if only sharded across batch dim + if all(x == -1 for x in dim_map[1:]): + return False + # TODO: whether there requires data exchange is currently determined by padding + return padding[-1] != 0 + + +def _is_supported(input_size, kernel_size, stride, padding, dilation): + if dilation[-1] != 1: + raise RuntimeError("Dilation must be 1 for tensor parallel convolution.") + if padding[-1] != 0: + if stride[-1] != 1: + raise RuntimeError( + "Stride must be 1 when there is padding for tensor parallel convolution." + ) + if kernel_size[-1] // 2 > input_size[-1]: + raise RuntimeError( + "kernel_size[-1] // 2 should be less than or equal to input_size[-1] for tensor parallel convolution." + ) + else: + if not (input_size[-1] % stride[-1] == 0 and stride[-1] == kernel_size[-1]): + raise RuntimeError( + "It requires that input_size[-1] is divisible by stride[-1] and stride[-1] equals kernel_size[-1] " + "when there is padding for tensor parallel convolution." + ) + return True + + +def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size): + # dist comms and reconstruct local input tensor + send_to_right = in_tensor[..., -d1:].contiguous() + send_to_left = in_tensor[..., :d2].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1) + elif rank == size - 1: + in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1) + else: + in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1) + + return in_tensor + + +def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size): + # dist comms and aggregate gradients for edge pixels + send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous() + send_to_left = grad_in_tensor[:, :, :, :d1].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + grad_in_tensor = grad_in_tensor[:, :, :, :-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + elif rank == size - 1: + grad_in_tensor = grad_in_tensor[:, :, :, d1:] + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + else: + grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + + +def tp_convolution( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], + dim_map: list[int], +) -> object: + assert op_call == aten.convolution.default + assert len(local_tensor_args) == 9 + + rank = dist.get_rank() + size = dist.get_world_size() + in_tensor = cast(torch.Tensor, local_tensor_args[0]) + weight = cast(torch.Tensor, local_tensor_args[1]) + stride, padding, dilation = local_tensor_args[3:6] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding, dim_map): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[-1] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step3 remove extra outputs from the results + padding_w = padding[-1] + w = local_results.size(-1) + if rank == 0: + local_results = local_results[..., : w - padding_w] + elif rank == size - 1: + local_results = local_results[..., padding_w:] + else: + local_results = local_results[..., padding_w : w - padding_w] + + return local_results + + +def tp_convolution_backward( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], + dim_map: list[int], +) -> object: + assert op_call == aten.convolution_backward.default + assert len(local_tensor_args) == 11 + + rank = dist.get_rank() + size = dist.get_world_size() + grad_out_tensor = cast(torch.Tensor, local_tensor_args[0]) + in_tensor = cast(torch.Tensor, local_tensor_args[1]) + weight = cast(torch.Tensor, local_tensor_args[2]) + stride, padding, dilation = local_tensor_args[4:7] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding, dim_map): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 reconstruct local gradient output tensor + padding_w = padding[1] + if rank == 0: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (0, padding_w), "constant", 0 + ) + elif rank == size - 1: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, 0), "constant", 0 + ) + else: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, padding_w), "constant", 0 + ) + + # step3 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = grad_out_tensor + local_tensor_args_list[1] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step4 aggregate gradients for edge pixels + grad_in_tensor = local_results[0] + if grad_in_tensor is not None: + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + local_results = list(local_results) + local_results[0] = grad_in_tensor + + local_results = cast(tuple[object, ...], local_results) + + return local_results + + +def convolution_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + output_spec = output_sharding.output_spec + assert isinstance(output_spec, dtensor.DTensorSpec) + + # local propagation + local_results = tp_convolution( + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + output_spec.dim_map, + ) + + return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec) + + +def convolution_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as input tensor + # pyrefly: ignore [bad-assignment] + args = list(args) + assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor) + # pyrefly: ignore [unsupported-operation] + args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec) + + # local propagation + local_results = tp_convolution_backward( + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + op_info.flat_args_schema[0].dim_map, + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f085b681f94911521683c7d566dc60124e1c9047 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_utils.py @@ -0,0 +1,461 @@ +import logging +import threading +from collections.abc import Sequence +from typing import Any, cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch._prims_common import ShapeType +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +logger = logging.getLogger(__name__) + + +class ExplicitRedistributionContext: + """ + Within this context manager, DTensor will refuse to perform implicit redistribution, + instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution + must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution. + + Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass + may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual + calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op + during forward and perform a manual redistribution during backwards. + + enable (bool) if False, disables the context manager. Can be used nested inside an enabled region. + + strict (bool) if True, triggers on any redistribution. If False, only triggers on redistributions that perform + communication. + + mode (str) Determines what happens when ExplicitRedistributionContext triggers: + "raise": raises an exceptoin, "warn" issues a warning + """ + + _local = threading.local() + + def __init__(self, enable: bool = True, strict: bool = False, mode="raise"): + self._enable = enable + self._strict = strict + if mode not in ("raise", "warn"): + raise RuntimeError(f"Invalid mode {mode}") + self._raise_on_redistribution = mode == "raise" + + @classmethod + def observe_redistribution( + cls, src_spec: DTensorSpec, dst_spec: DTensorSpec, message: str + ): + if instance := getattr(cls._local, "_active", None): + allowed = True + if instance._enable: + if instance._strict: + allowed = False + else: + allowed = redistribute_cost(src_spec, dst_spec) <= 0 + if not allowed: + if instance._raise_on_redistribution: + raise RuntimeError(message) + else: + logger.warning(message) + + def __enter__(self): + self._prev = getattr(ExplicitRedistributionContext._local, "_active", None) + ExplicitRedistributionContext._local._active = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ExplicitRedistributionContext._local._active = self._prev + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + skip_offset: bool = False, +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + Example: + global_tensor = [[0, 1, 2, 3, 4], sharded on mesh (DP=2, TP=2) with (Shard(1), Shard(1)) + [10, 11, 12, 13, 14]] + + This table shows the return value of local_shape and global_offset for each rank. + (`local_tensor` is for illustration only). + + Note how the first coordinate of global_offset is always 0, corresponding to tensor dim 0 being replicated. + + Rank local_tensor local_shape global_offset + ------------------------------------------------------------- + 0 [[0, 1], (2, 2) (0, 0) + [10, 11]] + + 1 [[2], (2, 1) (0, 2) + [12]] + + 2 [[3], (2, 1) (0, 3) + [13]] + + 3 [[4], (2, 1) (0, 4) + [14]] + + Args: + global_shape (ShapeType): The global shape of the DTensor. + mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. + placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Return: + local_shape: the shape of the DTensor's _local_tensor on the current rank. + global_offset: a tuple of offsets for each dimension of the global tensor shape, + identifying how this shard fits into the global tensor in each dimension. If + skip_offset is True, this will be an empty tuple. + + """ + return _compute_local_shape_and_global_offset( + global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset + ) + + +@maybe_run_for_local_tensor +def _get_shard_size_and_offsets( + curr_local_size: int, + mesh_dim_size: int, + rank: int, + placement: Shard | _StridedShard, + previous_offsets, + zero_global_offset: int, + skip_offset: bool, +) -> tuple[int, Optional[torch.Tensor]]: + kwargs: dict[str, Any] = { + "curr_local_size": curr_local_size, + "num_chunks": mesh_dim_size, + "rank": rank, + } + if isinstance(placement, _StridedShard): + kwargs["return_first_offset"] = False + shard_size, shard_offsets = placement._local_shard_size_and_offset(**kwargs) + if skip_offset: + return shard_size, None + if shard_size == 0: + return shard_size, torch.arange(zero_global_offset, zero_global_offset + 1) + if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): + assert isinstance(shard_offsets, int) + index = torch.arange(shard_offsets, shard_offsets + shard_size) + else: + assert isinstance(shard_offsets, list) + index = torch.tensor(shard_offsets) + if previous_offsets is None: + return shard_size, index + else: + return shard_size, previous_offsets[index] + + +@maybe_run_for_local_tensor +def _get_first_offset(offsets: torch.Tensor) -> int: + return int(offsets[0]) + + +# accept 'plain data types' to enable simpler unit testing without creating device mesh +def _compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh_shape: ShapeType, + my_coordinate: list[int] | None, + placements: Sequence[Placement], + skip_offset: bool = False, +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Suppose you have a full tensor with size global_shape, and you have sharded + it according to placements for mesh_shape. This function returns, for a + specific coordinate my_coordinate in the device mesh: + + - The size of your local shard WITHOUT padding (i.e., if you have + an uneven split, your size might be smaller than the other entries + in your dim), and + + - Where the data for your shard begins, in the full tensor. + + This function is fairly simple if your tensor is evenly sharded; the complication + is around uneven splits. There is also some complication for handling StridedShard, + which changes the order you should apply sharding. + + Args: + global_shape (ShapeType): The global shape of the tensor. + mesh_shape (ShapeType): The shape of the device mesh. + my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. + placements (Sequence[Placement]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Returns: + tuple: A tuple containing: + - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. + - global_offset (tuple[int, ...]): The offsets for each dimension identifying where + this shard begins in the global tensor. If skip_offset is True, this will be an + empty tuple. + """ + + empty_offset = () + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((0,), empty_offset) + + local_shape = list(global_shape) + # Perform shard from left to right. For example, + # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] + # placements: S(0), SS(0, split_factor=2) + # mesh_shape: (2, 2) + # After S(0), shard_dim_to_global_offsets are + # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] + # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] + # After SS(0, split_factor=2), shard_dim_to_global_offsets are + # {0: [0, 2]} on my_coordinate [0, 0] + # {0: [1, 3]} on my_coordinate [0, 1] + # {0: [4, 6]} on my_coordinate [1, 0] + # {0: [5, 7]} on my_coordinate [1, 1] + shard_dim_to_global_offsets = {} + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, (Shard, _StridedShard)): + continue + shard_dim = placement.dim + zero_global_offset = global_shape[shard_dim] + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + previous_offsets = shard_dim_to_global_offsets.get(shard_dim) + shard_size, shard_offsets = _get_shard_size_and_offsets( + local_shape[shard_dim], + mesh_shape[mesh_dim], + my_coordinate[mesh_dim], + placement, + previous_offsets, + zero_global_offset, + skip_offset, + ) + local_shape[shard_dim] = shard_size + shard_dim_to_global_offsets[shard_dim] = shard_offsets + if skip_offset: + return tuple(local_shape), empty_offset + global_offset = [0] * len(global_shape) + for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): + global_offset[shard_dim] = _get_first_offset(global_offsets) + return tuple(local_shape), tuple(global_offset) + + +compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info + + +def compute_local_tensor_info( + global_tensor: torch.Tensor, + mesh: DeviceMesh, + placements: Sequence[Placement], +) -> tuple[list[int], list[int]]: + """ + Compute the local size and stride of a DTensor from the given global tensor info. + + For example, if we have a global tensor with size (4, 8, 4) and stride (32, 1, 8). + If the DTensor placements are [Shard(2)] and world_size is 2; + then the local size is (4, 8, 2) and stride is (16, 1, 8). + + Args: + tensor (:class:`torch.Tensor`): + Global tensor which DTensor will distribute + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Returns: + local_shape: A List of int which specifies the size of the local tensor. + local_stride: A List of int which specifies the stride of the local tensor. + """ + local_shape = list(global_tensor.size()) + local_stride = list(global_tensor.stride()) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if placement.is_shard(): + shard_placement = cast(Shard, placement) + if shard_placement.dim < 0: + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) + shard_dim = shard_placement.dim + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)} " + f"for placement number {idx}." + ) + + global_dim_size = local_shape[shard_dim] + assert global_dim_size % mesh_dim_size == 0, ( + f"Global dim {global_dim_size} not divisible by mesh size {mesh_dim_size}" + ) + local_shape[shard_dim] = global_dim_size // mesh_dim_size + + # shrink strides that were scaled up globally + for i in range(len(local_stride)): + if ( + i != shard_dim + and local_stride[i] >= local_stride[shard_dim] * mesh_dim_size + ): + local_stride[i] = local_stride[i] // mesh_dim_size + + elif not isinstance(placement, (Replicate, Partial)): + raise RuntimeError(f"placement type {type(placement)} not supported!") + + return local_shape, local_stride + + +def compute_global_tensor_shape( + shape: torch.Size, mesh: DeviceMesh, placements: Sequence[Placement] +) -> torch.Size: + """ + Compute the global size of a DTensor from the given local tensor shape, + the mesh and placements. Different from `compute_global_tensor_info`, + which assumes sharding is even, this util allgathers local shards' shapes + from all ranks and thus can support uneven sharding. + NOTE: Currently this function only supports 1D mesh. + + Args: + shape (:class:`torch.Size`): + Shape of the local tensor + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Return: + tensor_shape: Shape of the global DTensor. + """ + if len(placements) != 1: + raise NotImplementedError( + "compute_global_tensor_shape only supports 1 placement for now." + ) + + if len(placements) != mesh.ndim: + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {mesh.ndim} mesh dims." + ) + + if isinstance(placements[0], Replicate): + return shape + elif isinstance(placements[0], Shard): + + @maybe_run_for_local_tensor + def _create_local_shape_tensor(shape): + return torch.tensor(list(shape), device=mesh.device_type) + + local_shape = _create_local_shape_tensor(shape) + gathered_shaped_tensors = [ + torch.empty_like(local_shape, device=local_shape.device) + for _ in range(mesh.size()) + ] + funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh) + + @maybe_run_for_local_tensor + def _validate_and_compute_global_shape(local_shape, gathered_shaped_tensors): + sharded_dim_sum = 0 + shard_dim = placements[0].dim # type: ignore[union-attr] + other_dims = [d for d in range(len(shape)) if d != shard_dim] + for shape_tensor in gathered_shaped_tensors: + if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): + raise RuntimeError( + "Non-sharded dimensions should have identical size across ranks." + ) + shape_tensor_list = shape_tensor.tolist() + sharded_dim_sum += shape_tensor_list[shard_dim] + return sharded_dim_sum + + sharded_dim_sum = _validate_and_compute_global_shape( + local_shape, gathered_shaped_tensors + ) + global_shape = list(shape) + global_shape[placements[0].dim] = sharded_dim_sum + return torch.Size(global_shape) + else: + raise NotImplementedError( + f"Placement type {type(placements[0])} not supported." + ) + + +def try_find_mesh_from_args( + op_call: torch._ops.OpOverload, args: Sequence[object] +) -> DeviceMesh: + """ + Find the device mesh object from args. + It returns None if no mesh is found. + NOTE: we can optimize this search if needed + """ + for arg in args: + if isinstance(arg, (dtensor.DTensor, DTensorSpec)): + return arg.device_mesh + elif ( + isinstance(arg, (list, tuple)) + and len(arg) > 0 + and isinstance(arg[0], (dtensor.DTensor, DTensorSpec)) + ): + return arg[0].device_mesh + + raise ValueError(f"Cannot find device mesh from args for op : {op_call}.") + + +def compute_local_stride( + global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[int, ...]: + """ + Compute the stride of a local tensor shard, given the global stride of the DTensor. + NOTE: Currently this function is assuming the DTensor is evenly shardable. + """ + stride_divisors = [1] * len(global_stride) + for mesh_idx, p in enumerate(placements): + if p.is_shard(): + i = cast(Shard, p).dim + # tensor dimension i is sharded on mesh dimension mesh_idx, + # so we need to divide all the strides larger than stride[i] + # (by the submesh size) + for j in range(len(global_stride)): + if global_stride[j] > global_stride[i]: + stride_divisors[j] *= mesh.size(mesh_idx) + return tuple( + global_stride[i] // stride_divisors[i] for i in range(len(global_stride)) + ) + + +def normalize_to_torch_size(size) -> torch.Size: # type: ignore[no-untyped-def] + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6aeca3b93a12deba923e8dd3094d36583cecd26 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__init__.py @@ -0,0 +1,52 @@ +# mypy: allow-untyped-defs +import torch._C +from torch.distributed.tensor.debug._comm_mode import CommDebugMode +from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding + + +__all__ = ["CommDebugMode", "visualize_sharding"] + + +def _get_python_sharding_prop_cache_info(): + """ + Get the cache info for the Python sharding propagation cache, used for debugging purpose only. + This would return a named tuple showing hits, misses, maxsize and cursize of the sharding + propagator cache. Note that directly calling into the sharding propagator does not share cache + state with the DTensor dispatch fast path! + """ + from torch.distributed.tensor._api import DTensor + + return ( + DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] + ) + + +def _get_fast_path_sharding_prop_cache_stats(): + """ + Get a tuple (hits, misses) for the fast path sharding propagation cache, used for debugging + only. + """ + return torch._C._get_DTensor_sharding_propagator_cache_stats() + + +def _clear_python_sharding_prop_cache(): + """ + Clears the cache for the Python sharding propagation cache, used for debugging purpose only. + """ + from torch.distributed.tensor._api import DTensor + + return ( + DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear() # type:ignore[attr-defined] + ) + + +def _clear_fast_path_sharding_prop_cache(): + """ + Clears the cache for the fast path sharding propagation cache, used for debugging purpose only. + """ + torch._C._clear_DTensor_sharding_propagator_cache() + + +# Set namespace for exposed private names +CommDebugMode.__module__ = "torch.distributed.tensor.debug" +visualize_sharding.__module__ = "torch.distributed.tensor.debug" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45ea9b4cd91cddc1f020c493d9e38b01381630cd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe823366d8d17a9cb50876ee511291229b87c3c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1167e1f4002b051ff3daefd442e41690e8c36236 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe01a8fac4de13d56b84ff29332b88e2ba5a55d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_comm_mode.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_comm_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..66ec0bfff5f9c74dfc4760729eead2bbf5c09e70 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_comm_mode.py @@ -0,0 +1,740 @@ +# mypy: allow-untyped-defs +import copy +import json +import re +import weakref +from collections import defaultdict +from typing import Any + +import torch +import torch.nn +from torch._guards import detect_fake_mode +from torch.autograd.graph import register_multi_grad_hook +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor._api import DTensor +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, + register_module_full_backward_pre_hook, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten + + +__all__ = ["CommDebugMode"] + +funcol_native = torch.ops._c10d_functional +funcol_py = torch.ops.c10d_functional +funcol_autograd = torch.ops._c10d_functional_autograd +c10d_ops = torch.ops.c10d + +NATIVE_TO_PY_MAPPING = { + funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor, + funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced, + funcol_native.all_reduce: funcol_py.all_reduce, + funcol_native.all_reduce_coalesced: funcol_py.all_reduce_coalesced, + funcol_native.all_to_all_single: funcol_py.all_to_all_single, + funcol_native.broadcast: funcol_py.broadcast, + funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor, + funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced, + # functional ops + funcol_autograd.all_to_all_single: funcol_py.all_to_all_single, +} + +c10d_collective_ops = { + c10d_ops._allgather_base_, + c10d_ops._reduce_scatter_base_, + c10d_ops.allgather_, + c10d_ops.allgather_coalesced_, + c10d_ops.allgather_into_tensor_coalesced_, + c10d_ops.allreduce_, + c10d_ops.allreduce_coalesced_, + c10d_ops.alltoall_, + c10d_ops.alltoall_base_, + c10d_ops.broadcast_, + c10d_ops.gather_, + c10d_ops.scatter_, + c10d_ops.reduce_, + c10d_ops.reduce_scatter_, + c10d_ops.reduce_scatter_tensor_coalesced_, +} + +trivial_ops = { + "aten.detach.default", + "aten.t.default", + "aten.view.default", + "aten._to_copy.default", + "aten.as_strided.default", + "aten.transpose.int", +} + + +class _CommModeModuleTracker(ModTracker): + """ + Inherits ModuleTracker and expands on its functionality to track the + parameters and sharding information of a model at a module-level + """ + + def __init__(self): + super().__init__() + self.module_helper_dict = {} + self.module_parameters_dict = {} + self.module_parents_dict = {} + self.register_forward_hook_handles = {} + self.parent_dict = {} + self.parent_list = [] + self.sharding_dict = {} + self.activation_checkpointing = False + self.name = "" + + def _fw_set_module_hook(self, mod, input, output): + """ + Updates the current module after module finishes running and + all other hooks are resolved + """ + + if self.is_bw: + self.activation_checkpointing = True + else: + self.activation_checkpointing = False + + if not self.activation_checkpointing: + # module is no longer parent of next modules + self.parent_list.pop() + + # set current module to previous parent module + self.name = self.parent_list[-1] + + def _fw_pre_hook(self, mod, input): + """ + This function is called before the forward pass of a module. It + collects the parameters and sharding information of a module and + stores it in a dictionary. + """ + if self.is_bw: + self.activation_checkpointing = True + else: + self.activation_checkpointing = False + + self.name = super()._get_mod_name(mod) + w_mod = weakref.ref(mod) + + # adds current sub-module to module tracker parent class + super()._get_append_fn(w_mod, self.name, False)() + + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook( + tensors, super()._get_pop_fn(w_mod, self.name, True) + ) + + if not self.activation_checkpointing: + # contains information about module ordering and depth in the module tree + if self.name not in self.module_helper_dict: + self.module_helper_dict[self.name] = {} + + self.module_helper_dict[self.name]["module_type"] = ( + str(type(mod)).replace("<", "").replace(">", "") + ) + self.module_helper_dict[self.name]["depth"] = len(self.parents) - 1 + + for param_name, param in mod.named_parameters(recurse=False): + if self.name not in self.module_parameters_dict: + self.module_parameters_dict[self.name] = {} + + self.module_parameters_dict[self.name][param_name] = param.data + + if isinstance(param.data, DTensor): + key_name = self.name + "." + param_name + self.sharding_dict[key_name] = param.data.placements + + if "parameters" not in self.module_helper_dict[self.name]: + self.module_helper_dict[self.name]["parameters"] = {} + + self.module_helper_dict[self.name]["parameters"][param_name] = str( + param.data.placements + ) + + # used to store module's parents to ensure correctness in backward pass/checkpointing + if self.name not in self.module_parents_dict: + self.module_parents_dict[self.name] = copy.deepcopy(self.parents) + + # used to create parent-child module associations for json dumps + parent = self.parent_list[-1] + if parent not in self.parent_dict: + self.parent_dict[parent] = [] + + self.parent_dict[parent].append(self.name) + self.parent_list.append(self.name) + + self.register_forward_hook_handles[self.name] = mod.register_forward_hook( + self._fw_set_module_hook + ) + + def _fw_post_hook(self, mod, input, output): # pylint: disable=useless-parent-delegation + """ + This function is called when the forward pass of a module is called. + It updates the module tracker and removes the module from parent data + """ + + super()._fw_post_hook(mod, input, output) + + def _bw_hook(self, mod, output): + """ + This function is called when the backward pass of a module is called. It + updates the current module for backward passes + """ + self.activation_checkpointing = False + self.name = super()._get_mod_name(mod) + + def __enter__(self): + self.activation_checkpointing = False + self.module_parameters_dict.clear() + self.sharding_dict.clear() + self.parent_dict.clear() + self.parent_list = ["Global"] + self.module_helper_dict.clear() + self.module_helper_dict["Global"] = {"depth": 0} + self.module_parents_dict.clear() + self.module_parents_dict["Global"] = set() + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook(self._fw_post_hook) + self.register_forward_hook_handles.clear() + self._bw_handle = register_module_full_backward_pre_hook(self._bw_hook) + self.name = "Global" + + def __exit__(self, *args): + super().__exit__(*args) + self._bw_handle.remove() + + # removes all forward_hook handles added in the pre-hook + for handle in self.register_forward_hook_handles.values(): + handle.remove() + + def print_paramater_info(self): + print(self.module_parameters_dict) + + def print_sharding_info(self): + for key, value in self.sharding_dict.items(): + print(key + ": " + str(value)) + + +class CommDebugMode(TorchDispatchMode): + """ + :class:`CommDebugMode` is a context manager that counts the number of + functional collectives within its context. It does this using a + ``TorchDispatchMode``. + + .. note:: Not all collectives are supported yet. + + Example usage + + .. code-block:: python + + mod = ... + comm_mode = CommDebugMode() + with comm_mode: + mod.sum().backward() + print(comm_mode.get_comm_counts()) + """ + + def __init__(self): + super().__init__() + self.comm_counts: dict[Any, int] = defaultdict(int) + self.comm_module_counts = {} + self.comm_module_operation_counts = {} + self.comm_registry = set() + for native_op, py_op in NATIVE_TO_PY_MAPPING.items(): + self.comm_registry.add(native_op) + self.comm_registry.add(py_op) + + self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) + self.advanced_module_tracker = _CommModeModuleTracker() + + def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): + """ + Creates json file used to build browser visual + 0. prints module-level collective counts + 1. prints dTensor operations not included in trivial operations + 2. prints operations not included in trivial operations + 3. prints all operations + """ + + ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) = self._set_noise_parameters(noise_level) + + # recursively builds json data + def add_json_information(json_dict, fqn): + json_dict["fqn"] = fqn + json_dict["module_type"] = "" + json_dict["parameters"] = [] + json_dict["children"] = [] + json_dict["collectives_forward"] = [] + json_dict["collectives_backward"] = [] + json_dict["operations_forward"] = [] + json_dict["operations_backward"] = [] + + # adds module layer type and parameters, and their sharding + if ( + "module_type" in self.advanced_module_tracker.module_helper_dict[fqn] + and include_module_data + ): + json_dict["module_type"] = ( + self.advanced_module_tracker.module_helper_dict[fqn]["module_type"] + ) + + if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: + for ( + param_name, + placement, + ) in self.advanced_module_tracker.module_helper_dict[fqn][ + "parameters" + ].items(): + json_dict["parameters"].append((param_name, placement)) + + # adds module collective information + if fqn in self.comm_module_counts: + for collective, count in self.comm_module_counts[fqn][ + "forward" + ].items(): + json_dict["collectives_forward"].append((str(collective), count)) + + for collective, count in self.comm_module_counts[fqn][ + "backward" + ].items(): + json_dict["collectives_backward"].append((str(collective), count)) + + # adds module operation information + forward_operations = [] + backward_operations = [] + checkpointing_operations = [] + + # only get operations if the minimum operation noise level is set to true + if include_DTensor_ops: + if fqn in self.comm_module_operation_counts: + ( + forward_operations, + backward_operations, + checkpointing_operations, + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) + + # remove all operations who don't have DTensor inputs + if not include_ops: + forward_operations = [ + op for op in forward_operations if len(op["input_sharding"]) + ] + backward_operations = [ + op for op in backward_operations if len(op["input_sharding"]) + ] + checkpointing_operations = [ + op for op in checkpointing_operations if len(op["input_sharding"]) + ] + + # remove all operations in trivial operations set + if not include_trivial_ops: + forward_operations = [ + op + for op in forward_operations + if str(op["name"]) not in trivial_ops + ] + backward_operations = [ + op + for op in backward_operations + if str(op["name"]) not in trivial_ops + ] + checkpointing_operations = [ + op + for op in checkpointing_operations + if str(op["name"]) not in trivial_ops + ] + + # converts operation information into string format for json.dumps() + forward_operations = copy.deepcopy(forward_operations) + for op in forward_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + backward_operations = copy.deepcopy(backward_operations) + for op in backward_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + checkpointing_operations = copy.deepcopy(checkpointing_operations) + for op in checkpointing_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + json_dict["operations_forward"] = forward_operations + json_dict["operations_backward"] = backward_operations + json_dict["operations_checkpointing"] = checkpointing_operations + + if fqn not in self.advanced_module_tracker.parent_dict: + return json_dict + + # recursively adds module's children + for ele in self.advanced_module_tracker.parent_dict[fqn]: + json_dict["children"].append(add_json_information({}, ele)) + + return json_dict + + json_dict: dict[str, Any] = {} + add_json_information(json_dict, "Global") + + # converts dictionary into json file + with open(file_name, "w") as json_file: + json.dump(json_dict, json_file, indent=4) + + def generate_comm_debug_tracing_table(self, noise_level=3): + """ + Generates detailed table displaying operations and collective tracing information + on a module level. Amount of information is dependent on noise_level + + 0. prints module-level collective counts + 1. prints dTensor operations not included in trivial operations, module information + 2. prints operations not included in trivial operations + 3. prints all operations + """ + + ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) = self._set_noise_parameters(noise_level) + + table = "" + for fqn in self.advanced_module_tracker.module_helper_dict: + # setting up indentations for table formatting + indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + ) + table += f"{indent}{fqn}\n" + + if include_module_data: + if ( + "module_type" + in self.advanced_module_tracker.module_helper_dict[fqn] + ): + module_type = self.advanced_module_tracker.module_helper_dict[fqn][ + "module_type" + ] + table += f"{indent}*module type: {module_type}\n" + + if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: + table += f"{indent}*Parameter List\n" + for ( + param_name, + placement, + ) in self.advanced_module_tracker.module_helper_dict[fqn][ + "parameters" + ].items(): + table += f"{indent} *{param_name}: {placement}\n" + + indent += " " + collective_indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 2 + ) + operation_indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 3 + ) + + # separate the module's collective and operations by forward and backward + forward_collectives = {} + backward_collectives = {} + if fqn in self.comm_module_counts: + forward_collectives = self.comm_module_counts[fqn]["forward"] + backward_collectives = self.comm_module_counts[fqn]["backward"] + + forward_operations = [] + backward_operations = [] + checkpointing_operations = [] + + if include_DTensor_ops: + if fqn in self.comm_module_operation_counts: + ( + forward_operations, + backward_operations, + checkpointing_operations, + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) + + def add_tracing_information(table, collectives_dict, operation_list): + """ + adds tracing information for module's forward or backward + """ + for collective, count in collectives_dict.items(): + table += ( + f"\033[1;33m{collective_indent}*{collective}: {count}\033[0m\n" + ) + + def add_operations( + table, operation, collective_indent, operation_indent + ): + """ + adds operation information to the table + """ + table += f"\033[1;33m{collective_indent}**{operation_name}\033[0m\n" + + if len(operation["input_shape"]): + operation_shape = operation["input_shape"] + operation_sharding = operation["input_sharding"] + operation_device_mesh = operation["device_mesh"] + + table += f"\033[1;31m{operation_indent}shape: {operation_shape}\033[0m\n" + table += f"\033[1;31m{operation_indent}sharding: {operation_sharding}\033[0m\n" + table += f"\033[1;31m{operation_indent}device mesh: {operation_device_mesh}\033[0m\n" + + return table + + for operation in operation_list: + operation_name = str(operation["name"]) + + # include all operations + if include_trivial_ops: + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + # include all operations not in trivial operations + elif include_ops and operation_name not in trivial_ops: + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + # only include dTensor operations not in trivial set + elif ( + include_DTensor_ops + and (operation_name not in trivial_ops) + and len(operation["input_shape"]) + ): + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + return table + + if len(forward_collectives) or len(forward_operations): + table += f"{indent}FORWARD PASS\n" + table = add_tracing_information( + table, forward_collectives, forward_operations + ) + + if len(backward_collectives) or len(backward_operations): + table += f"{indent}BACKWARD PASS\n" + table = add_tracing_information( + table, backward_collectives, backward_operations + ) + + if len(checkpointing_operations): + table += f"{indent}ACTIVATION CHECKPOINTING\n" + table = add_tracing_information(table, {}, checkpointing_operations) + + return table + + def _get_operations_list(self, module_operation_counts): + forward_operations = [ + op for op in module_operation_counts["operations_list"] if not op["is_bw"] + ] + backward_operations = [ + op + for op in module_operation_counts["operations_list"] + if op["is_bw"] and not op["is_activation_checkpointing"] + ] + checkpointing_operations = [ + op + for op in module_operation_counts["operations_list"] + if op["is_activation_checkpointing"] + ] + + return forward_operations, backward_operations, checkpointing_operations + + def get_total_counts(self) -> int: + return sum(self.comm_counts.values()) + + def get_comm_counts(self) -> dict[Any, int]: + """Returns the communication counts as a dictionary. + + Returns: + Dict[Any, int]: The communication counts as a dictionary. + """ + return self.comm_counts + + def get_parameter_info(self) -> dict[str, dict[str, Any]]: + return self.advanced_module_tracker.module_parameters_dict + + def get_sharding_info(self) -> dict[str, dict[str, Any]]: + return self.advanced_module_tracker.sharding_dict + + def __enter__(self): + self.comm_counts.clear() + self.comm_module_counts.clear() + self.comm_module_counts["Global"] = {} + self.comm_module_counts["Global"]["forward"] = defaultdict(int) + self.comm_module_counts["Global"]["backward"] = defaultdict(int) + + self.comm_module_operation_counts.clear() + + super().__enter__() + self.advanced_module_tracker.__enter__() + return self + + # pyrefly: ignore [bad-override] + def __exit__(self, *args): + self.advanced_module_tracker.__exit__() + super().__exit__(*args) + + def log_comm_debug_tracing_table_to_file( + self, file_name="comm_mode_log.txt", noise_level=3 + ): + """ + Alternative to console CommDebugMode output, writes to file specified by the user + """ + ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") + table = ansi_escape.sub("", self.generate_comm_debug_tracing_table(noise_level)) + + with open(file_name, "w") as log_file: + log_file.write(table) + + def _set_noise_parameters(self, noise_level): + """ + sets variables controlling what information displays based on noise level + """ + include_DTensor_ops = False + include_module_data = False + include_ops = False + include_trivial_ops = False + + if noise_level > 0: + include_DTensor_ops = True + include_module_data = True + + if noise_level > 1: + include_ops = True + + if noise_level > 2: + include_trivial_ops = True + + return ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into comms ops, before CommDebugMode sees them. + + # sets up operation-level collective count + if self.advanced_module_tracker.name not in self.comm_module_operation_counts: + # dictionary should hold module input and output shape, operations list and collective counter + self.comm_module_operation_counts[self.advanced_module_tracker.name] = { + "operations_list": [] + } + operation_dict = {} + operation_dict["name"] = func + + operation_dict["input_shape"] = [] + operation_dict["input_sharding"] = [] + operation_dict["device_mesh"] = "" + + # tracks if the operation is part of the backward pass + operation_dict["is_bw"] = self.advanced_module_tracker.is_bw + + # tracks if the operation is part of activation checkpointing + operation_dict["is_activation_checkpointing"] = ( + self.advanced_module_tracker.activation_checkpointing + ) + + if any(t == DTensor for t in types): + for ele in args: + if isinstance(ele, DTensor): + # saves shapes and placements of all DTensor args + operation_dict["input_shape"].append(ele.shape) + operation_dict["input_sharding"].append(ele.placements) + operation_dict["device_mesh"] = str(ele.device_mesh) + + self.comm_module_operation_counts[self.advanced_module_tracker.name][ + "operations_list" + ].append(operation_dict) + + return NotImplemented + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + func_packet = func._overloadpacket + + # We have many tests that use CommDebugMode to verify the occurrence of + # collectives. These tests do so by querying comm_counts with legacy + # funcol ops as key. For the purpose of native funcol migration, we + # need these tests to work for both legacy and native funcol. To avoid + # the need to modify all tests to accommodate the two implementations, + # we make CommDebugMode translate native funcol ops into legacy funcol + # ops until the migration finishes. + + if func_packet in self.comm_registry or func_packet in c10d_collective_ops: + if func_packet in NATIVE_TO_PY_MAPPING: + func_packet = NATIVE_TO_PY_MAPPING[func_packet] + self.comm_counts[func_packet] += 1 + + key = "forward" + if self.advanced_module_tracker.is_bw: + key = "backward" + + # adds collective count to current module + if self.advanced_module_tracker.name not in self.comm_module_counts: + self.comm_module_counts[self.advanced_module_tracker.name] = {} + self.comm_module_counts[self.advanced_module_tracker.name][ + "forward" + ] = defaultdict(int) + self.comm_module_counts[self.advanced_module_tracker.name][ + "backward" + ] = defaultdict(int) + self.comm_module_counts[self.advanced_module_tracker.name][key][ + func_packet + ] += 1 + + # adds collective count to parent modules + for par in self.advanced_module_tracker.module_parents_dict[ + self.advanced_module_tracker.name + ]: + # makes sure we aren't double counting when current sub-module hasn't been removed from parents + if par != self.advanced_module_tracker.name: + if par not in self.comm_module_counts: + self.comm_module_counts[par] = {} + self.comm_module_counts[par]["forward"] = defaultdict(int) + self.comm_module_counts[par]["backward"] = defaultdict(int) + self.comm_module_counts[par][key][func_packet] += 1 + + # if tensor op uses fake tensors, return + if detect_fake_mode(args): + return out + + # add tensor operation to module operation list + self.comm_module_operation_counts[self.advanced_module_tracker.name][ + "operations_list" + ].append(operation_dict) + + return out + + def __repr__(self): + return f"CommDebugMode(get_total_counts()={self.get_total_counts()})" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_op_coverage.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_op_coverage.py new file mode 100644 index 0000000000000000000000000000000000000000..7315d64d697a88f012e0bd67aa2e3e6e1141d7e1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_op_coverage.py @@ -0,0 +1,106 @@ +# mypy: allow-untyped-defs +from operator import itemgetter + +import torch +import torch.fx +import torch.nn as nn +from functorch.compile import make_boxed_func +from torch._functorch.compilers import aot_module +from torch._inductor.decomposition import select_decomp_table +from torch.distributed.tensor import DTensor + + +inductor_decomps = select_decomp_table() + +graphs: list[torch.fx.GraphModule] = [] + + +def fwd_bwd_compiler(fx_g, _): + graphs.append(fx_g) + return make_boxed_func(fx_g) + + +def get_inductor_decomp_graphs(model: nn.Module, args, kwargs): + """ + Obtain forward and backward graphs of a model with inductor decompositions using tracing and aot_module. + + Convenient util to get the fwd and bwd graphs of an arbitrary model + with inductor decompositions. Note that this would simply do tracing + with aot_module and don't ensure correctness. This is useful to track + the ops needed in DTensor. + """ + compiled_mod = aot_module( + model, fw_compiler=fwd_bwd_compiler, decompositions=inductor_decomps + ) + output = compiled_mod(*args, **kwargs) + + if output.ndim != 0: + # if output is not a scalar tensor, by default sum it in order to + # run backward + output = output.sum() + + output.backward() + + # one fwd, one bwd graph + assert len(graphs) == 2 + return graphs + + +def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=False): + """ + Util to print the operator coverage summary of a certain model with tabulute. + + Must have tabulate module installed. + """ + # python module required for summary + import csv + + from tabulate import tabulate + + fwd_graph, bwd_graph = get_inductor_decomp_graphs(model, args, kwargs) + + op_counts = {} + + for node in fwd_graph.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + if node.target not in op_counts: + op_counts[node.target] = 0 + + op_counts[node.target] += 1 + + for node in bwd_graph.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + if node.target not in op_counts: + op_counts[node.target] = 0 + + op_counts[node.target] += 1 + + op_infos = [] + + for op, count in op_counts.items(): + supported = op in DTensor._op_dispatcher.sharding_propagator.op_to_rules + op_infos.append([op, str(op._schema), count, supported]) + + # sort the op info base on the total count index + count_idx = 2 + op_infos.sort(key=itemgetter(count_idx), reverse=True) + + headers = ["Operator", "Schema", "Total Count", "Supported"] + # pyrefly: ignore [bad-argument-type] + print(tabulate(op_infos, headers=headers)) + + if output_csv: + # Open a CSV file for writing + with open("op_summary.csv", "w", newline="") as csv_file: + # Create a CSV writer object + csv_writer = csv.writer(csv_file) + + csv_writer.writerow(headers) + # Write each table row to the CSV file + for row in op_infos: + # pyrefly: ignore [bad-argument-type] + csv_writer.writerow(row) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..20dd0c3e9f4b47f5e8427855221b9e0c10535377 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import importlib.util + +import numpy as np + +from torch._prims_common import ShapeType +from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset + + +__all__ = ["visualize_sharding"] + +Color = tuple[float, float, float] + + +def _create_table( + shards: list[tuple[tuple[int, int], tuple[int, int], int]], device_kind: str = "" +): + """ + Creates a tabulate table given row and column ranges with device name + """ + from tabulate import tabulate + + # Extract unique row and column ranges + row_ranges = sorted({block[0] for block in shards}) + col_ranges = sorted({block[1] for block in shards}) + + # Create a matrix initialized with empty strings + matrix = [["" for _ in col_ranges] for _ in row_ranges] + + # Fill the matrix with values + for block in shards: + row_index = row_ranges.index(block[0]) + col_index = col_ranges.index(block[1]) + if matrix[row_index][col_index] == "": + matrix[row_index][col_index] = device_kind + ":" + str(block[2]) + else: + matrix[row_index][col_index] += "," + str(block[2]) + + # Prepare headers + row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges] + col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges] + + return tabulate(matrix, headers=col_headers, showindex=row_headers) + + +def make_color_iter(color_map, num_rows, num_cols): + num_colors = num_rows * num_cols + for idx in range(num_colors): + yield color_map(idx) + + +def _canonicalize_color(color: Color) -> str: + if isinstance(color, str): + return color + r, g, b = (int(a * 255) for a in color) + return f"#{r:02X}{g:02X}{b:02X}" + + +def _get_text_color(color: str) -> str: + r, g, b = map(lambda x: int(x, 16), (color[1:3], color[3:5], color[5:7])) # noqa: C417 + if (r * 0.299 + g * 0.587 + b * 0.114) > 186: + return "#000000" + return "#ffffff" + + +def _create_rich_table( + shape: ShapeType, + shards: list[tuple[tuple[int, int], tuple[int, int], int]], + device_kind: str = "", + scale: float = 1.0, + min_width: int = 9, + max_width: int = 80, +): + import matplotlib + import rich.align + import rich.box + import rich.console + import rich.padding + import rich.style + import rich.table + + dtensor_height = shape[0] + dtensor_width = shape[1] if len(shape) == 2 else 1 + + row_ranges = sorted({s[0] for s in shards}) + col_ranges = sorted({s[1] for s in shards}) + num_rows, num_cols = len(row_ranges), len(col_ranges) + + console = rich.console.Console(width=max_width) + use_color = console.color_system + color_iter = make_color_iter(matplotlib.colormaps["tab20b"], num_rows, num_cols) + + base_height = int(10 * scale) + aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0] + base_width = int(base_height * aspect_ratio) + height_to_width_ratio = 2.5 + + table = rich.table.Table( + show_header=False, + show_lines=not use_color, + padding=0, + highlight=not use_color, + pad_edge=False, + box=rich.box.SQUARE if not use_color else None, + ) + for row in range(num_rows): + table_row = [] + for col in range(num_cols): + entry = ( + device_kind + + ":" + + ",".join( + [ + str(device_id) + for row_range, col_range, device_id in shards + if row_range == row_ranges[row] and col_range == col_ranges[col] + ] + ) + ) + width = (col_ranges[col][1] - col_ranges[col][0]) / dtensor_width + width = int(width * base_width * height_to_width_ratio) + height = (row_ranges[row][1] - row_ranges[row][0]) / dtensor_height + height = int(height * base_height) + left_padding, remainder = divmod(width - len(entry) - 2, 2) + right_padding = left_padding + remainder + top_padding, remainder = divmod(height - 2, 2) + bottom_padding = top_padding + remainder + if use_color: + color = _canonicalize_color(next(color_iter)[:3]) + text_color = _get_text_color(color) + top_padding += 1 + bottom_padding += 1 + left_padding += 1 + right_padding += 1 + else: + color = None + text_color = None + padding = ( + max(top_padding, 0), + max(right_padding, 0), + max(bottom_padding, 0), + max(left_padding, 0), + ) + table_row.append( + rich.padding.Padding( + rich.align.Align(entry, "center", vertical="middle"), + padding, + style=rich.style.Style(bgcolor=color, color=text_color), + ) + ) + table.add_row(*table_row) + console.print(table, end="\n\n") + + +def visualize_sharding(dtensor, header="", use_rich: bool = False): + """ + Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D. + + .. note:: This requires the ``tabulate`` package, or ``rich`` and ``matplotlib``. + No sharding info will be printed for empty tensors + """ + if dtensor.numel() == 0: # Do not print empty dtensors. + return + + if len(dtensor.shape) >= 3: + raise RuntimeError("visualize sharding supports only 1D or 2D DTensor") + + if dtensor.device_mesh.get_coordinate() is None: # current rank is not in the mesh + return + + # Only display the visualization once for each DTensor, on the rank whose + # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh, + # we will only print on rank 0. + local_rank_zero_on_all_dim = all( + dtensor.device_mesh.get_local_rank(mesh_dim=dim) == 0 + for dim in range(dtensor.device_mesh.ndim) + ) + if not local_rank_zero_on_all_dim: + return + + device_coords = { + int(device_index.item()): list(coord) + for coord, device_index in np.ndenumerate( + np.array(dtensor.device_mesh.mesh.tolist()) + ) + } + + device_shard_shape_and_offsets = { + device_index: _compute_local_shape_and_global_offset( + dtensor.shape, + dtensor.device_mesh.shape, + device_coords[device_index], + dtensor.placements, + ) + for device_index in device_coords + } + + # Extend shards in a 1D tensor to 2D + device_shard_shape_and_offsets = { + device_index: ( + shape if len(shape) == 2 else (shape[0], 1), + offset if len(offset) == 2 else (offset[0], 0), + ) + for device_index, (shape, offset) in device_shard_shape_and_offsets.items() + } + + shards = [ + ( + (offset[0], offset[0] + shape[0] - 1), + (offset[1], offset[1] + shape[1] - 1), + device_index, + ) + for device_index, (shape, offset) in device_shard_shape_and_offsets.items() + ] + + if ( + importlib.util.find_spec("rich") + and importlib.util.find_spec("matplotlib") + and use_rich + ): + _create_rich_table( + dtensor.shape, shards, device_kind=dtensor.device_mesh.device_type + ) + elif importlib.util.find_spec("tabulate"): + print(_create_table(shards, device_kind=dtensor.device_mesh.device_type)) + else: + raise ValueError("`visualize_sharding` requires either `rich` or `tabulate`.") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..ca59ded5eb52bc0a3878e76077ad2879df4bf499 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/device_mesh.py @@ -0,0 +1,9 @@ +from torch.distributed.device_mesh import ( # noqa: F401 + _get_device_handle, + _mesh_resources, + DeviceMesh, + init_device_mesh, +) + + +__all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e4cd25f47265dd7e3c0101b3b1b6ab3a3ecec3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98ecbe3592d6df90f4d87af6f55b7abc895023c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c87abbeaeb6a11013a70f5d642a628ce44d6adbe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275a7725b51b159f6727e287811994af054da6ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e110e39322249ae71995c29c785fbc0bc165d8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..009255631796fc56b6607ae46b6ae4f91589e83b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Context Parallel components + +from ._attention import ( + _CausalBehavior, + _context_parallel_shard, + _ContextParallel, + _cp_options, + _disable_context_parallel_dispatcher, + _enable_context_parallel_dispatcher, + _is_causal_behavior, + _RotateMethod, + context_parallel, + context_parallel_unshard, + set_rotate_method, +) +from ._cp_custom_ops import flex_cp_allgather +from ._load_balancer import ( + _HeadTailLoadBalancer, + _LoadBalancer, + _PerDocumentHeadTailLoadBalancer, + _PTRRLoadBalancer, +) + + +__all__ = [ + # From _attention + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", + # From _cp_custom_ops + "flex_cp_allgather", + # From _load_balancer + "_HeadTailLoadBalancer", + "_LoadBalancer", + "_PerDocumentHeadTailLoadBalancer", + "_PTRRLoadBalancer", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68cdb8dfbcd6d250efb5199dd3f11da853249868 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_attention.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09b95c7e999f2614e710f9557ab25089106d1570 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_attention.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_cp_custom_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_cp_custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..142853253d02bfd04d8601f9de70e394327dbf99 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_cp_custom_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_load_balancer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_load_balancer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..184ebef80e5e313d3f5a3b4806b44d947ff0630d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_load_balancer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_sharding_rules.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_sharding_rules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fa46855118dae19372f27219221c5e48b531e6b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/__pycache__/_sharding_rules.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1c6299dfca4912736feddd818f33e1e618e8d9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -0,0 +1,1675 @@ +import contextlib +import itertools +import logging +import types +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from enum import auto, Enum +from functools import partial +from typing import Any, cast, Protocol, TypeAlias + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as ft_c +import torch.distributed.distributed_c10d as c10d +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import distribute_tensor, DTensor, Shard +from torch.distributed.tensor.parallel import ParallelStyle +from torch.nn.attention.flex_attention import ( + _mask_mod_signature, + BlockMask, + create_block_mask, +) +from torch.utils._pytree import tree_flatten, tree_unflatten + +from ._cp_custom_ops import flex_cp_allgather +from ._load_balancer import _create_default_load_balancer, _LoadBalancer + + +__all__ = [ + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", +] + + +class _CausalBehavior(Enum): + SKIP = None + NOT_IS_CAUSAL = False + IS_CAUSAL = True + + +class _RotateMethod(Enum): + ALL_TO_ALL = auto() + ALL_GATHER = auto() + + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +class _DispatchMode(Enum): + MONKEY_PATCH = auto() + MODULE_WRAPPER = auto() + + +_dispatch_mode: _DispatchMode = _DispatchMode.MONKEY_PATCH + + +@dataclass +class _ContextParallelOptions: + # Whether to upcast parameters and gradients to float32 to avoid accumulation + # errors. It is likely this is always True, but we currently keep this variable + # for experimental purposes. + convert_to_f32: bool = True + enable_load_balance: bool = True + rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER + + +_cp_options = _ContextParallelOptions() + + +def _is_causal_behavior( + rank: int, world_size: int, i: int, is_causal: bool +) -> _CausalBehavior: + """ + Calculate is_causal behavior for each KV block. The attention can either be + calculated in full, not at all or with the causal mask applied. + """ + if not is_causal: + return _CausalBehavior.NOT_IS_CAUSAL + + if i == 0: + return _CausalBehavior.IS_CAUSAL + + source_rank = (rank - i) % world_size + if source_rank < rank or _cp_options.enable_load_balance: + return _CausalBehavior.NOT_IS_CAUSAL + else: + return _CausalBehavior.SKIP + + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + + +def _partial_update( + original: torch.Tensor, + new: torch.Tensor, + dim: int, + n_chunks: int, + idx: int, + add: bool, +) -> torch.Tensor: + """ + This API partially updates a chunk of ``original`` tensor. The ``original`` + tensor will be first chunked along ``dim`` dimension, then the ``idx`` chunk + will be updated with ``new``. If ``add`` is True, the chunk will be added + with ``new``, otherwise the chunk will be replaced by ``new``. + + The result is a tensor that is the same size as ``original``. + """ + chunks = list(original.chunk(n_chunks, dim=dim)) + assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) + if add: + chunks[idx] += new + else: + chunks[idx] = new + return torch.cat(chunks, dim=dim) + + +class _SDPAMerger: + """A class to help merge the local SDPA result.""" + + def __init__(self, convert_to_f32: bool, seq_dim: int): + self._seq_dim = seq_dim + self._out: torch.Tensor | None = None + self._lse: torch.Tensor | None = None + self._should_lse_squeeze = False + self._convert_to_f32 = convert_to_f32 + self._out_dtype = torch.float32 + self._lse_dtype = torch.float32 + + def _merge_one( + self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool + ) -> None: + # The cuDNN backend preserves the last dimension for LSE. + # Apply unsqueeze only if the input does not already have + # the required dimensionality. + if len(block_lse.shape) < len(block_out.shape): + block_lse = block_lse.unsqueeze(dim=-1) + self._should_lse_squeeze = True + assert len(block_lse.shape) == len(block_out.shape) + + if self._lse is None: + self._lse = block_lse + self._out = block_out + else: + ROUND_ROBIN_CYCLE = 2 + assert self._lse is not None + assert self._out is not None + lse = ( + self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._lse + ) + out = ( + self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._out + ) + + # The algorithm from + # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # gives a relatively stable result. + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + if partial: + self._lse = _partial_update( + self._lse, + lse, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + self._out = _partial_update( + self._out, + out, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + else: + self._lse = lse + self._out = out + + def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: + self._out_dtype = out.dtype + self._lse_dtype = lse.dtype + + if self._convert_to_f32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + self._merge_one(out, lse, partial) + + def results(self) -> tuple[torch.Tensor, torch.Tensor]: + assert self._out is not None + assert self._lse is not None + out = self._out.to(self._out_dtype) + if self._should_lse_squeeze: + lse = self._lse.squeeze(-1).to(self._lse_dtype) + else: + lse = self._lse.to(self._lse_dtype) + return out, lse + + +class _AttentionOp(Protocol): + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs: object, + ) -> tuple[torch.Tensor, ...]: ... + + +class _RingRotater(ABC): + @abstractmethod + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ... + + @abstractmethod + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ... + + @abstractmethod + def next_buffer(self) -> torch.Tensor: ... + + +class _AllToAllRotater(_RingRotater): + """Use all_to_all to send the kv to the next rank.""" + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._buffer: torch.Tensor | None = None + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + curr_buffer = curr_buffer.contiguous() + size = dist.get_world_size(self._pg) + dsts = list(range(1, size)) + [0] + self._buffer = ft_c.permute_tensor(curr_buffer, dsts, self._pg) + + def next_buffer(self) -> torch.Tensor: + assert self._buffer is not None + return _maybe_wait(self._buffer) + + +class _AllGatherRotater(_RingRotater): + """ + Allgather the kv and return only the required kv. + Only one communication will be done. + """ + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._aggregated_buffer: torch.Tensor | None = None + self._idx = 0 + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + # We only need to perform allgather once. + self._idx += 1 + if self._aggregated_buffer is None: + self._aggregated_buffer = ft_c.all_gather_tensor( + curr_buffer.contiguous(), gather_dim=0, group=self._pg + ) + + def next_buffer(self) -> torch.Tensor: + rank = dist.get_rank(self._pg) + idx = rank - self._idx + + assert self._aggregated_buffer is not None + self._aggregated_buffer = _maybe_wait(self._aggregated_buffer) + return self._aggregated_buffer.chunk(dist.get_world_size(self._pg))[idx] + + +def _create_rotater( + pg: dist.ProcessGroup, seq_dim: int, method: _RotateMethod | None = None +) -> _RingRotater: + if method is None: + method = _cp_options.rotate_method + + if method == _RotateMethod.ALL_TO_ALL: + return _AllToAllRotater(pg, seq_dim) + elif method == _RotateMethod.ALL_GATHER: + return _AllGatherRotater(pg, seq_dim) + else: + raise NotImplementedError(f"Unknown method {method}") + + +def _templated_ring_attention( + group: dist.ProcessGroup, + seq_dim: int, + op: _AttentionOp, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + **kwargs: object, +) -> tuple[torch.Tensor, ...]: + """ + A generalized ring attention implementation that can support multiple attention ops. + + Note [Context parallelism load balance algorithm for causal masking] + ===================== + This explanation uses an example to illustrate the CP algorithm with causal + masking. + + Consider a scenario where the sequence length of q, k, and v is 4 (e.g., + q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss + only q and k, as v follows the same pattern as k. + + The diagram below represents a complete QK^T operation without parallelism. + The `****` entries indicate that the result is not required due to causal + masking (e.g., q0k1 is marked as `****`). + + +----+------------------------+ + | | k0 k1 k2 k3 | + +----+------------------------+ + | q0 | q0k0, ****, ****, **** | + | q1 | q1k0, q1k1, ****, **** | + | q2 | q2k0, q2k1, q2k2, **** | + | q3 | q3k0, q3k1, q3k2, q3k3 | + +----+------------------------+ + + ### No Load Balance: + + In this scenario, each rank owns a local chunk of q, k, and v, with each chunk + containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), + while rank1 manages (q2, q3) and (k2, k3). + + First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. + Causal masking is enabled as some results are not required (e.g., q0k1). + + Second Iteration: Local queries remain the same, but local kv pairs are exchanged. + Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs + no computation, while rank1 computes locally without causal masking since all results + (q2k0, q2k1, q3k0, q3k1) are needed. + + ### Round-robin Load Balance: + + In this setup, each rank owns two local chunks of q, k, and v, with each chunk + containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) + and (k1, k2). Although the local chunks are not consecutive, they are concatenated to + enable SDPA to be performed in a single call for each step. Consequently, the chunk() + function may be required to prepare the correct q, k, and v configurations. + + First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the + no-load-balance case. This iteration corresponds to the `if` of the + (`if, `elif`, `else`) in the implementation. + + Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and + (k0, k3). For rank0, no computation is needed for q0. However, computations for + q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the + `else` of the (`if`, `elif`, `else`) in the implementation. + For rank1, k3 is not needed for q1 and q2, so only k0 is used for SDPA. This + corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. + + Parameters + ---------- + op: + The attention op to use + *args: + additional args are passed to the op + **kwargs: + additional kwargs are passed to the op + + Returns + ------- + out: + The merged attention output + softmax_lse: + The logsumexp of the merged attention output + """ + if is_causal and (query.size(2) != key.size(2)): + raise NotImplementedError( + "is_causal requires the same query and context sequence lengths" + ) + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + + assert isinstance(group, dist.ProcessGroup), ( + "process group must be single dimension" + ) + rank = dist.get_rank(group) + size = dist.get_world_size(group) + + next_kv = None + + # Without making key and value contiguous(), the loss curve is bad. + # TODO(fegin): figure out why this is a requirement since SDPA does not have + # this requirement. + key = key.contiguous() + value = value.contiguous() + + sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) + + rest: list[Any] + out: torch.Tensor + logsumexp: torch.Tensor + + rotater = _create_rotater(group, 2) + + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + next_kv = rotater.next_buffer() + key = next_kv[: key.numel()].reshape(key.shape) + value = next_kv[key.numel() :].reshape(value.shape) + + if i < (size - 1): + # Send the k, v to the next rank + next_kv = torch.cat([key.flatten(), value.flatten()]) + next_kv = rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + # For a detailed understanding of the load balancing algorithm, see + # Note [Context parallelism load balance algorithm for causal masking] + if is_causal_behavior == _CausalBehavior.SKIP: + # If i > rank and load balancing is not turned on. + continue + + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # When local balance is enabled, we still need to do SDPA with + # the both local chunks of q, k, v for the first iteration. + q, k, v, partial = (query, key, value, False) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SDPA with only the first local chunk of k, v. + # Note that q, k, v each contains two local chunks. + ROUND_ROBIN_CYCLE = 2 + q, k, v, partial = ( + query, + key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + False, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SDPA with only the second half of q, and update + # only the second part of logsumexp. So partial is True. + # Note that q, k, v each contains two chunks. + q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True + + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + out, logsumexp, *rest = op( + q, + k, + v, + is_causal=is_causal_behavior.value, + **kwargs, + ) + sdpa_merger.step(out, logsumexp, partial) + + # pyrefly: ignore [unbound-name] + return *sdpa_merger.results(), *rest + + +def _templated_ring_attention_backward( + group: dist.ProcessGroup, + seq_dim: int, + op: _AttentionOp, + grad_out: torch.Tensor, + grad_out_name: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + is_causal: bool, + **kwargs: Any, +) -> tuple[torch.Tensor, ...]: + """This API implements the backward pass of the ring attention.""" + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + rank = dist.get_rank(group) + size = dist.get_world_size(group) + next_kv = None + next_grad_kv = None + rest: list[Any] + grad_query_, grad_key_, grad_value_ = None, None, None + + accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype + grad_query = torch.zeros_like(query, dtype=accum_dtype) + grad_key = torch.zeros_like(key, dtype=accum_dtype) + grad_value = torch.zeros_like(value, dtype=accum_dtype) + + key = key.contiguous() + value = value.contiguous() + kv_rotater = _create_rotater(group, 2) + dkv_rotater = _create_rotater(group, 2, method=_RotateMethod.ALL_TO_ALL) + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + buffer = kv_rotater.next_buffer() + pointer = 0 + key = buffer[pointer : pointer + key.numel()].reshape(key.shape) + pointer += key.numel() + value = buffer[pointer : pointer + value.numel()].reshape(value.shape) + pointer += value.numel() + + if i != size - 1: + # Send the kv to the next rank. + next_kv = torch.cat([key.flatten(), value.flatten()]) + kv_rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + if is_causal_behavior != _CausalBehavior.SKIP: + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # We need to do SDPA with the full local q, k, v. + q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SDPA with only the first half of k, v. + # Note that q, k, v each contains two chunks. + q, k, v, out_, dout, lse = ( + query, + key.chunk(2, dim=seq_dim)[0], + value.chunk(2, dim=seq_dim)[0], + out, + grad_out, + logsumexp, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SDPA with only the second half of q. + # Note that q, k, v each contains two chunks. + q, k, v, out_, dout, lse = ( + query.chunk(2, dim=seq_dim)[1], + key, + value, + out.chunk(2, dim=seq_dim)[1], + grad_out.chunk(2, dim=seq_dim)[1], + # Need to make logsumexp contiguous, otherwise there will + # be numerical error. + logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), + ) + + kwargs[grad_out_name] = dout + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + grad_query_, grad_key_, grad_value_, *rest = op( + query=q, + key=k, + value=v, + out=out_, + logsumexp=lse, + is_causal=is_causal_behavior.value, + **kwargs, + ) + else: + grad_query_ = torch.zeros_like(query, dtype=accum_dtype) + grad_key_ = torch.zeros_like(key, dtype=accum_dtype) + grad_value_ = torch.zeros_like(value, dtype=accum_dtype) + + ROUND_ROBIN_CYCLE = 2 + if i == 0: + grad_key += grad_key_ + grad_value += grad_value_ + else: + pointer = 0 + # Wait for the kv gradient from (cp_rank - 1) rank. + next_grad_kv = dkv_rotater.next_buffer() + grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( + grad_key.shape + ) + pointer += grad_key.numel() + grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape( + grad_value.shape + ) + + if i <= rank and _cp_options.enable_load_balance: + grad_key = _partial_update( + grad_key, + grad_key_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + grad_value = _partial_update( + grad_value, + grad_value_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + else: + grad_key += grad_key_ + grad_value += grad_value_ + + next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) + # Send the grad key and grad value to the next rank. + dkv_rotater.exchange_buffers(next_grad_kv) + + if i <= rank or not _cp_options.enable_load_balance: + grad_query += grad_query_ + else: + grad_query = _partial_update( + grad_query, + grad_query_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=True, + ) + + assert grad_key_ is not None + assert grad_value_ is not None + grad_query = grad_query.to(query.dtype) + next_grad_kv = dkv_rotater.next_buffer().to(key.dtype) + grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape) + grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape) + return ( + grad_query, + grad_key, + grad_value, + # pyrefly: ignore [unbound-name] + *rest, + ) + + +def _scaled_dot_product_ring_flash_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + if return_debug_mask: + raise NotImplementedError("return_debug_mask is not supported yet") + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: torch.Tensor | None = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_efficient_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + attn_bias=attn_bias, + dropout_p=dropout_p, + scale=scale, + compute_log_sumexp=compute_log_sumexp, + ) + + +def _scaled_dot_product_ring_cudnn_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: torch.Tensor | None = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=return_debug_mask, + scale=scale, + ) + + +def _scaled_dot_product_ring_flash_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_flash_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + is_causal=is_causal, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bias: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + dropout_p: float, + grad_input_mask: tuple[bool, ...], + is_causal: bool = False, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_efficient_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out_", + query=query, + key=key, + value=value, + attn_bias=bias, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=dropout_p, + grad_input_mask=grad_input_mask, + is_causal=is_causal, + scale=scale, + ) + + +def _scaled_dot_product_ring_cudnn_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + attn_bias: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + *, + scale: float | None = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_cudnn_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=attn_bias, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + +def _sdpa_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + # TODO: remove the context parallel strategy from the default propagation + # rule. Either figure out how to dynamically enable it or just don't call + # propagate. + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + call_maps: dict[torch._ops.OpOverload, Callable] = { + aten._scaled_dot_product_flash_attention.default: _scaled_dot_product_ring_flash_attention, + aten._scaled_dot_product_efficient_attention.default: _scaled_dot_product_ring_efficient_attention, + aten._scaled_dot_product_cudnn_attention.default: _scaled_dot_product_ring_cudnn_attention, + aten._scaled_dot_product_flash_attention_backward.default: _scaled_dot_product_ring_flash_attention_backward, + aten._scaled_dot_product_efficient_attention_backward.default: _scaled_dot_product_ring_efficient_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward.default: _scaled_dot_product_ring_cudnn_attention_backward, + } + if op_call in call_maps: + local_results = call_maps[op_call]( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError( + "CP only supports flash attention and memory efficient attention now." + ) + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +custom_ops = { + aten._scaled_dot_product_flash_attention.default: _sdpa_handler, + aten._scaled_dot_product_flash_attention_backward.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_handler, +} +exitsing_custom_ops = DTensor._op_dispatcher._custom_op_handlers + + +ArgsType = tuple[Any, ...] +KwargsType = dict[str, Any] +InputFnType = Callable[[nn.Module | None, ArgsType, KwargsType, DeviceMesh], Any] +OutputFnType = Callable[[nn.Module | None, Any, Any, DeviceMesh], Any] + +_replaced_functions: dict[Callable, tuple[str, Callable]] = {} + + +def _distribute_function( + fn: Callable, + fn_module: types.ModuleType, + device_mesh: DeviceMesh, + input_fn: InputFnType, + output_fn: OutputFnType, +) -> None: + """ + A helper function to replace a function with a distributed version by + using the monkey patching approach. + + This function is for the CP internal usage only. + """ + + def wrapper( + target_fn: Callable, input_fn: InputFnType, output_fn: OutputFnType + ) -> Callable: + def inner_fn(*args: ArgsType, **kwargs: KwargsType) -> Any: + args, kwargs = input_fn(None, args, kwargs, device_mesh) + outputs = target_fn(*args, **kwargs) + return output_fn(None, (args, kwargs), outputs, device_mesh) + + return inner_fn + + global _replaced_functions + + if fn in _replaced_functions: + return + + wrapper_fn = wrapper(fn, input_fn, output_fn) + setattr(fn_module, fn.__name__, wrapper_fn) + _replaced_functions[wrapper_fn] = (fn.__name__, fn) + + +def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None: + """Restore the function that is replaced by _distribute_function.""" + if fn not in _replaced_functions: + return + + original_name, original_fn = _replaced_functions[fn] + setattr(fn_module, original_name, original_fn) + + +def _enable_cp_dtensor_dispatcher() -> None: + """Enables DTensor dispatcher to dispatch SDPA to CP.""" + # Enable custom op handlers for CP + DTensor._op_dispatcher._custom_op_handlers = { + **exitsing_custom_ops, + **custom_ops, + } + # Register CP-specific sharding rules + from ._sharding_rules import register_cp_sharding_rules + + register_cp_sharding_rules() + + +def _disable_cp_dtensor_dispatcher() -> None: + """Disables DTensor dispatcher to dispatch SDPA to CP.""" + # Restore original custom op handlers + DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops + + # TODO: unregister_cp_sharding_rules(clear_the_cache=True) will cause + # all DTensor sharding propagation cache being invalidated. It is not + # easy to achieve selectively invalidating lru cache without rewriting + # the sharding propagation wrapper. + + from ._sharding_rules import unregister_cp_sharding_rules + + unregister_cp_sharding_rules(clear_the_cache=False) + + +def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None: + sdpa_cp = _ContextParallel( + seq_dim=seq_dim, + attention_type=_ContextParallel.AttentionType.SDPA, + ) + + if _dispatch_mode == _DispatchMode.MONKEY_PATCH: + _distribute_function( + F.scaled_dot_product_attention, + F, + mesh, + sdpa_cp.sdpa_input_fn, + sdpa_cp.sdpa_output_fn, + ) + _enable_cp_dtensor_dispatcher() + elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: + _enable_cp_dtensor_dispatcher() + else: + raise ValueError(f"Unknown dispatch mode: {_dispatch_mode}") + + +def _disable_context_parallel_dispatcher_impl() -> None: + if _dispatch_mode == _DispatchMode.MONKEY_PATCH: + _restore_function(F.scaled_dot_product_attention, F) + elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: + pass + else: + raise NotImplementedError(f"Unknown dispatch mode: {_dispatch_mode}") + + _disable_cp_dtensor_dispatcher() + + +_compiled_create_block_mask = None + + +def _context_parallel_buffers( + mesh: DeviceMesh, + buffers: list[torch.Tensor | BlockMask], + buffer_seq_dims: list[int], + load_balancer: _LoadBalancer | None = None, +) -> list[torch.Tensor | BlockMask]: + """ + Shard the buffers along the sequence dimensions according to CP rules. + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (List[torch.Tensor]): the buffers to be sharded. + seq_dims (List[int]): the sequence dimensions of ``buffers``. This list + must have the same length as ``buffers``. + load_balancer (Optional[:class:`_LoadBalancer`]): an optional `_LoadBalancer` + object. If this argument is `None`, it means the `buffers` need no + rearrangement before being sharded. If this argument is a `_LoadBalancer` + object, call its `_generate_indices(restore=False)` to generate the + rearrangement indices such that each shard of `buffer[rearrange_idx]` is + well-balanced (i.e., having close sparsities). + + Returns: + List[torch.Tensor]: the sharded buffers. + + Note: + For `_context_parallel_shard` we require a non-None `load_balancer` object to be + explicitly passed if load-balancing is needed. + """ + # generate the index tensor for rearranging the buffer if a load-balance + # is available + load_balance_indices = load_balancer._generate_indices() if load_balancer else None + assert load_balance_indices is None or load_balance_indices.ndim == 2, ( + "load balance index expects shape (1, seq_len) or (B, seq_len) " + f"but got {load_balance_indices.shape}." + ) + + new_buffers = [] + sharded_buffer: torch.Tensor | BlockMask + for buffer, seq_dim in zip(buffers, buffer_seq_dims): + if isinstance(buffer, torch.Tensor): + # TODO: the load balance doesn't perform error handling. + + # NOTE: assuming batch dim is 0 + + if load_balance_indices is not None: + # TODO: we should expclitly ask users to unsqueeze the batch dim. + # But this is a BC breaking ask. + # However, what we have done today is also not very safe. + idx_batch_size = load_balance_indices.size(0) + data_batch_size = buffer.size(0) if seq_dim > 0 else 1 + + if idx_batch_size != 1 and idx_batch_size != data_batch_size: + raise ValueError( + "Cannot rearrange buffer: " + f"load_balance_indices has shape {load_balance_indices.shape}, " + f"but buffer has shape {buffer.shape}." + ) + + if seq_dim == 0: + buffer = torch.index_select( + buffer, dim=0, index=load_balance_indices[0] + ) + else: + indices = load_balance_indices + if idx_batch_size == 1: + size = [data_batch_size] + list(indices.size())[1:] + indices = indices.expand(*size) + + for i in range(data_batch_size): + buffer[i] = torch.index_select( + buffer[i], dim=seq_dim - 1, index=indices[i] + ) + + # use DTensor to shard the buffer on sequence dimension, retain the local tensor + sharded_buffer = distribute_tensor( + buffer, mesh, [Shard(seq_dim)], src_data_rank=None + ).to_local() + elif isinstance(buffer, BlockMask): + sharded_buffer = _create_cp_block_mask( + mask_mod=buffer.mask_mod, + B=buffer.kv_num_blocks.shape[0], + H=buffer.kv_num_blocks.shape[1], + Q_LEN=buffer.seq_lengths[0], + KV_LEN=buffer.seq_lengths[1], + device_mesh=mesh, + load_balancer=load_balancer, + ) + else: + raise ValueError(f"Unknown buffer type: {type(buffer)}") + + new_buffers.append(sharded_buffer) + + return new_buffers + + +def _create_cp_block_mask( + mask_mod: _mask_mod_signature, + B: int, + H: int, + Q_LEN: int, + KV_LEN: int, + device_mesh: DeviceMesh, + load_balancer: _LoadBalancer | None = None, +) -> BlockMask: + """ + Creates a specialized BlockMask for Context Parallel FlexAttention. + + This function creates a BlockMask that enables computation of attention results + for sharded Q attending to global KV. The mask appropriately handles the query + index offset required when each rank operates on a shard of the query sequence + while accessing the full key-value sequence. + + The function internally rewrites the provided mask_mod function to translate local + query indices to global query indices, ensuring that the masking logic is applied + correctly across the distributed computation. + + Args: + mask_mod (Callable): Mask function that operates on global attention indices. + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Global sequence length of the query. + KV_LEN (int): Global sequence length of the key/value. + device_mesh (DeviceMesh): Device mesh used for context parallelism. + load_balancer (Optional[:class:`_LoadBalancer`]): The load-balancer used to rearrange + QKV before sharding. This will be used to modify the block_mask generated. + + Returns: + BlockMask: A block mask configured for the local query shard that can be used + with flex_attention() for the given cp_mesh. + + Raises: + NotImplementedError: If Q_LEN is not divisible by (CP world size * BLOCK_SIZE). + + Warning: + Currently requires Q_LEN to be divisible by CP mesh world size * BLOCK_SIZE + (BLOCK_SIZE defaults to 128). This constraint exists because the BlockMask + must handle both padding and offsets correctly. For example, if Q_LEN is 384, + CP world size is 2, and BLOCK_SIZE is 128, the local Q_LEN would be 192. In + such cases, both rank0 and rank1 would have paddings in their local BlockMasks. + Support for padding in this scenario is planned for future work. + + """ + + from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE + + if Q_LEN % (device_mesh.size() * _DEFAULT_SPARSE_BLOCK_SIZE) != 0: + raise NotImplementedError( + f"Q_LEN {Q_LEN} is not divisible by CP mesh world size {device_mesh.size()} * " + f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. " + ) + + global _compiled_create_block_mask + if _compiled_create_block_mask is None: + _compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True + ) + compiled_create_block_mask = _compiled_create_block_mask + + def _rewrite_mask_mod( + mask_mod: _mask_mod_signature, + rank: int, + block_size: int, + local_q_size: int, + qkv_rearrange_indices: torch.Tensor | None = None, + ) -> _mask_mod_signature: + assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, ( + "load balance index expects shape (1, seq_len) or (B, seq_len) " + f"but got {qkv_rearrange_indices.shape}." + ) + + def qkv_idx_restore( + b: torch.Tensor, idx_post_rearrange: torch.Tensor + ) -> torch.Tensor: + if qkv_rearrange_indices is not None: + if ( + qkv_rearrange_indices.size(0) == 1 + ): # identical load-balance in batch + idx_pre_rearrange = qkv_rearrange_indices[0][idx_post_rearrange] + else: + idx_pre_rearrange = qkv_rearrange_indices[b][idx_post_rearrange] + else: + idx_pre_rearrange = idx_post_rearrange + + return idx_pre_rearrange + + def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor: + # calculate local block_idx and block_offset + local_blk_idx, local_blk_offset = ( + local_q_idx // block_size, + local_q_idx % block_size, + ) + # NOTE: load balancing is not used + local_num_blocks = local_q_size // block_size + blk_idx = local_num_blocks * rank + local_blk_idx + return blk_idx * block_size + local_blk_offset + + return lambda b, h, q_idx, kv_idx: mask_mod( + b, + h, + qkv_idx_restore(b, local_q_idx_to_q_idx(q_idx)), + qkv_idx_restore(b, kv_idx), + ) + + cp_rank = device_mesh.get_local_rank() + cp_group_size = device_mesh.size() + load_balancer = load_balancer or _create_default_load_balancer( + Q_LEN, cp_group_size, device_mesh.device_type + ) + Q_SHARD_LEN = Q_LEN // cp_group_size + block_size = _DEFAULT_SPARSE_BLOCK_SIZE + + rearrange_indices = ( + load_balancer._generate_indices(restore=False) if load_balancer else None + ) + block_mask = compiled_create_block_mask( + _rewrite_mask_mod( + mask_mod, + cp_rank, + block_size, + Q_SHARD_LEN, + qkv_rearrange_indices=rearrange_indices, + ), + B, + H, + Q_SHARD_LEN, + KV_LEN, + device=device_mesh.device_type, + BLOCK_SIZE=(block_size, block_size), + ) + return block_mask + + +##################### +# Experimental APIs +##################### + + +class _ContextParallel(ParallelStyle): + class AttentionType(Enum): + FLEX = "flex_attention" + SDPA = "scaled_dot_product_attention" + + def __init__( + self, + seq_dim: int, + attention_type: AttentionType, + ) -> None: + super().__init__() + self.seq_dim = seq_dim + self.attention_type = attention_type + + def _apply(self, module: nn.Module, mesh: DeviceMesh) -> nn.Module: + if self.attention_type == self.AttentionType.FLEX: + module.register_forward_pre_hook( + partial(self.flex_input_fn, mesh=mesh), with_kwargs=True + ) + return module + elif self.attention_type == self.AttentionType.SDPA: + module.register_forward_pre_hook( + partial(self.sdpa_input_fn, mesh=mesh), with_kwargs=True + ) + module.register_forward_hook(partial(self.sdpa_output_fn, mesh=mesh)) + return module + else: + raise ValueError(f"Unknown attention type: {self.attention_type}") + + def flex_input_fn( + self, module: nn.Module | None, args: Any, kwargs: Any, mesh: DeviceMesh + ) -> Any: + args_list = list(args) + for idx, name in enumerate( + ("query", "key", "value", "score_mod", "block_mask") + ): + if idx >= len(args): + args_list.append(kwargs.pop(name, None)) + + query, key, value, score_mod, block_mask = args_list[:5] + assert isinstance(query, torch.Tensor) + assert isinstance(key, torch.Tensor) + assert isinstance(value, torch.Tensor) + assert isinstance(block_mask, BlockMask | tuple) + + key = key.contiguous() + value = value.contiguous() + + global_key, global_value = flex_cp_allgather( + key, value, self.seq_dim, c10d._get_process_group_name(mesh.get_group()) + ) + args_list[1] = global_key + args_list[2] = global_value + + return tuple(args_list), kwargs + + def sdpa_input_fn( + self, + module: nn.Module | None, + args: tuple[Any, ...], + kwargs: dict[str, Any], + mesh: DeviceMesh, + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + placement = [Shard(self.seq_dim)] + all_args = [] + + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, torch.Tensor): + if isinstance(arg, DTensor): + assert arg._spec.placements == placement + else: + arg = DTensor.from_local(arg, mesh, placement, run_check=False) + + all_args.append(arg) + + new_args = tuple(all_args[0 : len(args)]) + new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :])) + return new_args, new_kwargs + + def sdpa_output_fn( + self, module: nn.Module | None, inputs: Any, outputs: Any, mesh: DeviceMesh + ) -> Any: + new_outputs = [] + for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: + output = output.to_local() if isinstance(output, DTensor) else output + new_outputs.append(output) + + if isinstance(outputs, torch.Tensor): + return new_outputs[0] + + return tuple(new_outputs) + + +CPBuffer: TypeAlias = torch.Tensor | BlockMask +CPBufferContainer: TypeAlias = Sequence[CPBuffer] | Mapping[str, CPBuffer] +CPBufferSeqDims: TypeAlias = Sequence[int] | Mapping[str, int] + + +def _context_parallel_shard( + mesh: DeviceMesh, + buffers: CPBufferContainer, + seq_dims: CPBufferSeqDims, + load_balancer: _LoadBalancer | None = None, +) -> list[torch.Tensor | BlockMask]: + """ + Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each + rank retains only its corresponding shard according to the provided `mesh`. If a + `load_balancer` is provided, the buffers will be rearranged by the load balancer + before sharding to improve load balance. Buffers can be either tensors or `BlockMask` + objects. If a buffer is a `BlockMask`, its sharding dimension is determined by the + `BlockMask` implementation, and the corresponding `seq_dim` is ignored. + + Note: + For `_context_parallel_shard`, a non-None `load_balancer` must be explicitly passed + if load balancing is required. + + Args: + mesh (DeviceMesh): The device mesh used for context parallelism. + buffers (List[torch.Tensor | BlockMask]): Buffers whose usage depends on the sequence + dimension. Examples include input batches, labels, and positional embedding buffers. + These buffers must be sharded along the sequence dimension to ensure correctness. + seq_dims (List[int]): The sequence dimensions for each buffer in `buffers`. Must have + the same length as `buffers`. + load_balancer (Optional[_LoadBalancer]): An optional load balancer object. If provided, + it rearranges the buffers before sharding to achieve better load balance. If not + provided, no rearrangement is performed. + + Returns: + List[torch.Tensor | BlockMask]: The sharded buffers, each corresponding to the local + shard for the current rank. + """ + # TODO: these global variables are going to bite us someday. + # We will have to remove them soon. + # For the new API, we only support the module wrapper mode. + global _dispatch_mode + _dispatch_mode = _DispatchMode.MODULE_WRAPPER + global _cp_options + if load_balancer is not None: + _cp_options.enable_load_balance = True + else: + _cp_options.enable_load_balance = False + + if len(buffers) != len(seq_dims): + raise ValueError( + "`seq_dims` must have the same number of elements as `buffers`." + ) + + flat_buffers, spec = tree_flatten(buffers) + flat_seq_dims, _ = tree_flatten(seq_dims) + if len(flat_buffers) != len(flat_seq_dims): + raise ValueError("`seq_dims` must have the pytree structure as `buffers`.") + + if isinstance(flat_buffers[0], torch.Tensor): + device = flat_buffers[0].device + else: + device = flat_buffers[0].kv_num_blocks.device + for buffer in flat_buffers: + if isinstance(buffer, torch.Tensor): + assert device == buffer.device, "All buffers must be on the same device" + else: + assert device == buffer.kv_num_blocks.device, ( + "All buffers must be on the same device" + ) + + flat_sharded_buffers = _context_parallel_buffers( + mesh, flat_buffers, flat_seq_dims, load_balancer + ) + + return tree_unflatten(flat_sharded_buffers, spec) + + +def _enable_context_parallel_dispatcher() -> None: + """ + Enable the context parallel dispatcher. This API is experimental and subject to change. + """ + _enable_cp_dtensor_dispatcher() + + +def _disable_context_parallel_dispatcher() -> None: + """ + Disable the context parallel dispatcher. This API is experimental and subject to change. + """ + _disable_cp_dtensor_dispatcher() + + +##################################################### +# Current public APIs, but are also subject to change +##################################################### +@contextlib.contextmanager +@torch.no_grad() +def context_parallel( + mesh: DeviceMesh, + *, + buffers: list[torch.Tensor] | None = None, + buffer_seq_dims: list[int] | None = None, + no_restore_buffers: set[torch.Tensor] | None = None, +) -> Generator[None, None, None]: + """ + + ``context_parallel`` is an experimental API to enable context + parallelism (CP). This API performs two actions: 1) patch the SDPA + (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled + one, 2) shard ``buffers`` along the sequence dimension and each rank will + preserve the corresponding shard according ``mesh``. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (Optional[List[torch.Tensor]]): buffers that the usage depend + on the sequence dimension. Examples are input batch, labels and + positional embedding buffers. These buffers must be sharded along + the sequence dimension to ensure the accuracy. The sharding will + happen in-place, the buffer's shape will change within the context. + The buffers will be restored after the context finishes. + ``no_restore_buffers`` can be used to specify which buffers don't + need to be restored. Note that ``buffers`` should not contain any + nn.Parameter. + buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``. + no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in these set + won't be restored after the context exits. This set must be a subset + of ``buffers``. If the buffers won't be used after the context exits, + these buffers can be put in this list to avoid extra restore time. + + .. warning:: + `torch.distributed.tensor.experimental.context_parallel` is a + prototype feature in PyTorch. The API is subject to change. + """ + # For the legacy API, we only support the monkey-patch mode. + # We will deprecate this API once the new API is widely used. + global _dispatch_mode + _dispatch_mode = _DispatchMode.MONKEY_PATCH + + buffers = [] if buffers is None else buffers + buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims + no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers + + if len(buffers) != len(buffer_seq_dims): + raise ValueError( + "`seq_dims` must have the same number of elements as `buffers`." + ) + + for buffer in no_restore_buffers: + # Cannot use `if not buffer in buffers` which will incur tensor comparison. + if not any(b is buffer for b in buffers): + raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") + + original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] + + device = buffers[0].device + seq_length = buffers[0].shape[buffer_seq_dims[0]] + cp_world_size = mesh.size() + + # If `enable_load_balance` is True, the default Head-tail load balancer + # (:class:`_HeadTailLoadBalancer`) is used to rearrange the buffers before + # sharding. Otherwise, we don't do any load-balance rearrange by passing + # `None` to `_context_parallel_shard()`. + load_balancer = _create_default_load_balancer(seq_length, cp_world_size, device) + shards = _context_parallel_buffers( + mesh, + cast(list[torch.Tensor | BlockMask], buffers), + buffer_seq_dims, + load_balancer, + ) + for buffer, shard in zip(buffers, shards): + assert isinstance(shard, torch.Tensor), "ContextParallel only supports Tensor" + shard = shard.clone() + buffer.resize_(shard.shape) + buffer.copy_(shard) + + _enable_context_parallel_dispatcher_impl(seq_dim=2, mesh=mesh) + yield + _disable_context_parallel_dispatcher_impl() + + for buffer, original_buffer in zip(buffers, original_buffers): + if original_buffer is not None: + buffer.resize_(original_buffer.shape) + buffer.copy_(original_buffer) + + +@torch.no_grad() +def context_parallel_unshard( + mesh: DeviceMesh, + buffers: list[torch.Tensor], + seq_dims: list[int], + load_balancer: _LoadBalancer | None = None, +) -> list[torch.Tensor]: + """ + Unshard the tensors (e.g., output) that are sharded due to context parallelism. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (List[torch.Tensor]): the buffers to be unsharded. + seq_dims (List[int]): the sequence dimensions of ``buffers``. This list + must have the same length as ``buffers``. + load_balancer (Optional[:class:`_Loadbalancer`]): an optional `_LoadBalancer` + object. If this argument is `None`, it means the `buffers` were not + rearranged when being sharded and there's no need to put it back to order + after unsharding. If this argument is a `_LoadBalancer` object, call + its `_generate_indices(restore=True)` to generate the restore indices such + that `unsharded[restore_idx]` is the original buffer. + + Returns: + List[torch.Tensor]: the unsharded buffers. + + Note: + For `context_parallel_unshard` we require not-None `load_balancer` object be + explicitly passed if flex_attention() is to be used and load-balancing is needed. + This is different from the case of SDPA though we strongly suggest users follow + the same convention. + """ + device = buffers[0].device + cp_world_size = mesh.size() + seq_length = buffers[0].shape[seq_dims[0]] * cp_world_size + + # If users don't pass in a `load_balancer`: + # - if `enable_load_balance` is True, we use the default round-robin + # load balancer. + # - if `enable_load_balance` is False, we don't do any load balancing + # by passing in `None` as `restore_indices`. + load_balancer = load_balancer or _create_default_load_balancer( + seq_length, cp_world_size, device + ) + restore_indices = ( + load_balancer._generate_indices(restore=True) if load_balancer else None + ) + + assert restore_indices is None or restore_indices.ndim == 2, ( + "load balance restore index expects shape (1, seq_len) or (B, seq_len) " + f"but got {restore_indices.shape}." + ) + unsharded_buffers = [] + for b, dim in zip(buffers, seq_dims): + b = b.contiguous() + unsharded_b = _maybe_wait(ft_c.all_gather_tensor(b, dim, mesh)) + + if restore_indices is not None: + # NOTE: assuming batch dim is 0 + idx_batch_size = restore_indices.size(0) + data_batch_size = unsharded_b.size(0) + if idx_batch_size != 1 and idx_batch_size != data_batch_size: + raise ValueError( + "Cannot restore buffer: " + f"restore_indices has shape {restore_indices.shape}, " + f"but unsharded_b has shape {unsharded_b.shape}." + ) + + for i in range(data_batch_size): + index = ( + restore_indices[0] # identical load-balance in batch + if idx_batch_size == 1 + else restore_indices[i] + ) + unsharded_b_batch_i = torch.index_select( + unsharded_b[i], dim=dim - 1, index=index + ) + unsharded_b[i] = unsharded_b_batch_i + + unsharded_buffers.append(unsharded_b) + + return unsharded_buffers + + +def set_rotate_method(rotate_method: str) -> None: + """ + Context Parallel SDPA requires the rotation of kv shards. Users can call this + API to specify which rotation method to use. "alltoall" shuffles the kv shards + using all-to-all collective. While "allgather" gathers the kv shards using + all-gather collective after the first sub-SDPA computation. If this API has not + been called, the default rotate method is "allgather". + + Args: + rotate_method (str): the rotate method to use. Currently only supports + "allgather" and "alltoall". If a different string other than these two + is passed in, the function will raise an error. + + Returns: + None + """ + logger.info("Note that FlexAttention CP doesn't support alltoall yet.") + if rotate_method == "allgather": + _cp_options.rotate_method = _RotateMethod.ALL_GATHER + elif rotate_method == "alltoall": + _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL + else: + raise NotImplementedError( + "Context Parallel does not support " + f"using {rotate_method} for kv shards rotation" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..760ca79cfcbe54971afd23d467437e790c4b5fe2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py @@ -0,0 +1,88 @@ +from typing import Any + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d + + +@torch.library.custom_op("cplib::flex_cp_allgather", mutates_args=()) +def flex_cp_allgather( + k: torch.Tensor, v: torch.Tensor, seq_dim: int, pg_name: c10d.GroupName +) -> tuple[torch.Tensor, torch.Tensor]: + k = k.contiguous() + v = v.contiguous() + k = funcol.all_gather_tensor(k, seq_dim, pg_name) + v = funcol.all_gather_tensor(v, seq_dim, pg_name) + if isinstance(k, funcol.AsyncCollectiveTensor): + k = k.wait() + if isinstance(v, funcol.AsyncCollectiveTensor): + v = v.wait() + return k, v + + +@flex_cp_allgather.register_fake +def _( + k: torch.Tensor, v: torch.Tensor, seq_dim: int, pg_name: c10d.GroupName +) -> tuple[torch.Tensor, torch.Tensor]: + shape_k = list(k.shape) + shape_v = list(v.shape) + shape_k[seq_dim] *= c10d._get_group_size_by_name(pg_name) + shape_v[seq_dim] *= c10d._get_group_size_by_name(pg_name) + new_k = torch.empty(shape_k, dtype=k.dtype, device=k.device) + new_v = torch.empty(shape_v, dtype=v.dtype, device=v.device) + return new_k, new_v + + +@torch.library.custom_op("cplib::flex_cp_allgather_backward", mutates_args=()) +def flex_cp_allgather_backward( + grad_full_k: torch.Tensor, + grad_full_v: torch.Tensor, + seq_dim: int, + pg_name: c10d.GroupName, +) -> tuple[torch.Tensor, torch.Tensor]: + grad_k = funcol.reduce_scatter_tensor(grad_full_k, "sum", seq_dim, pg_name) + if isinstance(grad_k, funcol.AsyncCollectiveTensor): + grad_k = grad_k.wait() + grad_v = funcol.reduce_scatter_tensor(grad_full_v, "sum", seq_dim, pg_name) + if isinstance(grad_v, funcol.AsyncCollectiveTensor): + grad_v = grad_v.wait() + + return grad_k, grad_v + + +@flex_cp_allgather_backward.register_fake +def _( + grad_full_k: torch.Tensor, + grad_full_v: torch.Tensor, + seq_dim: int, + pg_name: c10d.GroupName, +) -> tuple[torch.Tensor, torch.Tensor]: + shape_k = list(grad_full_k.shape) + shape_v = list(grad_full_v.shape) + shape_k[seq_dim] //= c10d._get_group_size_by_name(pg_name) + shape_v[seq_dim] //= c10d._get_group_size_by_name(pg_name) + new_grad_k = torch.empty( + shape_k, dtype=grad_full_k.dtype, device=grad_full_k.device + ) + new_grad_v = torch.empty( + shape_v, dtype=grad_full_v.dtype, device=grad_full_v.device + ) + return new_grad_k, new_grad_v + + +def _flex_cp_allgather_backward( + ctx: Any, grad_full_k: torch.Tensor, grad_full_v: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, None, None]: + grad_k, grad_v = flex_cp_allgather_backward( + grad_full_k, grad_full_v, ctx.seq_dim, ctx.pg_name + ) + return grad_k, grad_v, None, None + + +def _flex_cp_setup_context(ctx: Any, inputs: Any, output: Any) -> None: + _, _, ctx.seq_dim, ctx.pg_name = inputs + + +flex_cp_allgather.register_autograd( + _flex_cp_allgather_backward, setup_context=_flex_cp_setup_context +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b293b0e260efcc9e8fd308579ecaa0e3d48c0e6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py @@ -0,0 +1,486 @@ +# this file contains the `_LoadBalancer` class and its family of implementation +# for different load-balancing strategies in tensor sharding. +import functools +from abc import ABC, abstractmethod + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import BlockMask + + +# make it private since it's still a prototype +class _LoadBalancer(ABC): + @abstractmethod + def _generate_indices(self, restore: bool = False) -> Tensor | None: + """ + Generate indices for load balancing. + Args: + restore (bool): + + Returns: + The generated indices of shape `(1, seq_len)` if the load-balancing is + identical within the batch, or `(batch_size, seq_len)` if the load-balancing + should vary within the batch. + + Warning: + For Multi-Head Attention, we require the masks over the head dimension are identical + (i.e. the return value of `_generate_indices()` does not have `heads` dimension). + + Example: + Here is the causal mask for attention where q_len == kv_len == 8: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 0, 0, 0, 0, 0] + Q_index [1, 1, 1, 1, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 0, 0, 0] + [1, 1, 1, 1, 1, 1, 0, 0] + [1, 1, 1, 1, 1, 1, 1, 0] + [1, 1, 1, 1, 1, 1, 1, 1] + + This mask matrix also represents the computation required to compute + the masked Q @ K^T by: + - mask[i, j] == 1: the computation of Q[i, :] dot K[j, :] is required + - mask[i, j] == 0: the computation should be skipped + + Therefore the number of 1s in matrix represents the amount of computation + required. + + Assume we want to distribute this Q @ K^T computation to 2 devices, then + the matrix is also distributed as: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 0, 0, 0, 0, 0] rank 0 + [1, 1, 1, 1, 0, 0, 0, 0] + Q_index ------------------------ + [1, 1, 1, 1, 1, 0, 0, 0] + [1, 1, 1, 1, 1, 1, 0, 0] rank 1 + [1, 1, 1, 1, 1, 1, 1, 0] + [1, 1, 1, 1, 1, 1, 1, 1] + + An imbalance of computation is observed on these 2 ranks and this could make + rank 1 the straggler when performing Context Parallel. In order to balance + the computation, we need to rearrange the QKV tensors before sharding in such a + way that the result mask matrix is evenly distributed over devices and each + rank has the number of 1s as close as possible. + + This method defines the strategy of how to rearrange the QKV tensor for better + load-balance: + - when `restore == False`, this method returns an indices tensor `rearrange_idx` + such that Q[rearrange_idx] is the desired Q tensor after rearranging. + - when `restore == True`, this method returns an indices tensor `restore_idx` + such that Q[rearrange_idx][restore_idx] == Q, i.e. restoring the rearranged tensor + back to the original status before rearranging. + """ + + +class _HeadTailLoadBalancer(_LoadBalancer): + def __init__(self, seq_length: int, world_size: int, device: str | torch.device): + self.seq_length = seq_length + self.world_size = world_size + self.device = device + + def _generate_indices(self, restore: bool = False) -> Tensor: + """ + Generate head-and-tail load balancing indices or restore indices. + Args: + restore: + If True, generate restore indices that map head-and-tail rearranged + positions back to original positions. If False, generate load + balance indices that rearrange original positions to head-and-tail pattern. + + Returns: + The generated indices of shape `(1, seq_len)` because the load-balancing is + identical within the batch. + + Warning: + For Multi-Head Attention, we require the masks over the head dimension are identical + (i.e. the return value of `_generate_indices()` does not have `heads` dimension). + + Example: + Here is the causal mask for attention where q_len == kv_len == 8: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 0, 0, 0, 0, 0] + Q_index [1, 1, 1, 1, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 0, 0, 0] + [1, 1, 1, 1, 1, 1, 0, 0] + [1, 1, 1, 1, 1, 1, 1, 0] + [1, 1, 1, 1, 1, 1, 1, 1] + + Head-tail load-balance strategy rearranges the Q tensor by combining + Q[0:k] (on seq dim) and Q[-k:] for rank 0, Q[k:2k] and Q[-2k:-k] for + rank 1, and so on. In python code it looks like: + + k = Q.size(0) // (2 * cp_world_size) + for rank in range(cp_world_size): + reordered_Q[rank * 2 * k : (rank + 1) * 2 * k] = torch.cat( + (Q[rank * k : (rank + 1) * k], Q[-(rank + 1) * k : -rank * k]) + ) + + This can also be done by tensor slicing. For the above example, the indices + tensor for slicing is: + slice_indices = Tensor([0, 7, 1, 6, 2, 5, 3, 4]) + + After reordering QKV using the `slice_indices`, the corresponding mask matrix + distributing over 2 devices becomes well-balanced: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 1, 1, 1] + [1, 1, 0, 0, 0, 0, 0, 0] rank 0 + [1, 1, 1, 1, 1, 1, 1, 0] + Q_index ------------------------ + [1, 1, 1, 0, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 1, 0, 0] rank 1 + [1, 1, 1, 1, 0, 0, 0, 0] + [1, 1, 1, 1, 1, 0, 0, 0] + + To restore the reordering and putting the tensor back, slicing op can do the + trick with a `restore_indices` such that: + slice_indices[restore_indices] == Tensor([0, 1, 2, ...]) + + In this way, `reordered_Q[restore_indices]` will just be the original Q. + """ + seq_length = self.seq_length + world_size = self.world_size + assert seq_length % (world_size * 2) == 0 + chunk_size = seq_length // (world_size * 2) + all_indices = [] + + for rank in range(world_size): + # Generate indices for first chunk of the cp rank + first_chunk_start = rank * chunk_size + first_chunk_indices = list( + range(first_chunk_start, first_chunk_start + chunk_size) + ) + + # Second chunk: positions from the complementary chunk + second_chunk_idx = world_size * 2 - rank - 1 + second_chunk_start = second_chunk_idx * chunk_size + second_chunk_indices = list( + range(second_chunk_start, second_chunk_start + chunk_size) + ) + # combine the indices for this rank + all_indices.extend(first_chunk_indices + second_chunk_indices) + + all_indices_tensor = torch.tensor( + all_indices, dtype=torch.int, device=self.device + ) + if restore: + all_indices_tensor = torch.argsort(all_indices_tensor) + + return all_indices_tensor.unsqueeze(0) # add batch dim + + +class _PerDocumentHeadTailLoadBalancer(_LoadBalancer): + def __init__( + self, + seq_length_per_doc: list[list[int]], + world_size: int, + device: str | torch.device, + ): + """ + `seq_length_per_doc` has size (B, seq_len) if the load-balancing should vary + within the batch. Otherwise `seq_length_per_doc` should have size (1, seq_len). + """ + self.seq_length_per_doc = seq_length_per_doc + self.world_size = world_size + self.device = device + + def _generate_indices(self, restore: bool = False) -> Tensor: + """ + Generate the per-document head-and-tail rearrange indices so that after rearranging + the input is load-balanced in per-document head-and-tail style. + + Args: + restore: + If True, generate restore indices that map per-document head-and-tail + rearranged positions back to original positions. If False, generate load + balance indices that rearrange original positions to per-document + head-and-tail pattern. + + Returns: + The generated indices of shape `(batch_size, seq_len)` if the load-balancing + should vary within the batch. Otherwise, it should have shape `(1, seq_len)`. + + Warning: + For Multi-Head Attention, we require the masks over the head dimension are identical + (i.e. `seq_length_per_doc` must have size (B, seq_len) or (1, seq_len)). + + Example: + Here is the document causal mask for attention where q_len == kv_len == 16: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + Q_index [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1] + + The per-document head-and-tail load-balancer will apply head-and-tail + reordering within each document. After load-balancing for context-parallel + on 2 devices, the above mask matrix will look like this: + KV_index + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] + Q_index [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1] + ------------------------------------------------ + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0] + """ + return torch.stack( + [ + self._generate_indices_for_batch(seq_lengths, restore) + for seq_lengths in self.seq_length_per_doc + ] + ) + + def _generate_indices_for_batch(self, seq_length_per_doc, restore) -> Tensor: # type: ignore[no-untyped-def] + world_size = self.world_size + device = self.device + assert all( + seq_length % (2 * world_size) == 0 for seq_length in seq_length_per_doc + ) + chunk_length_per_doc = [ + seq_length // (2 * world_size) for seq_length in seq_length_per_doc + ] + + indices = [] + document_start_idx = 0 + for seq_length, chunk_length in zip(seq_length_per_doc, chunk_length_per_doc): + # Generate the indices for the current document + for rank in range(world_size): + head_chunk_start_idx = document_start_idx + chunk_length * rank + tail_chunk_end_idx = document_start_idx + chunk_length * ( + 2 * world_size - rank + ) + indices.append( + torch.arange( + head_chunk_start_idx, + head_chunk_start_idx + chunk_length, + device=device, + ) + ) + indices.append( + torch.arange( + tail_chunk_end_idx - chunk_length, + tail_chunk_end_idx, + device=device, + ) + ) + + document_start_idx += seq_length + + indices_tensor = torch.cat(indices) + if restore: + indices_tensor = torch.argsort(indices_tensor) + + return indices_tensor + + +class _PTRRLoadBalancer(_LoadBalancer): + """ + Processing-Time based Round-Robin (PTRR) load balancer. This load balancer should + only be used for flex_attention() since it leverages `BlockMask`. + """ + + def __init__( + self, + block_mask: BlockMask, + world_size: int, + ): + """ + `block_mask` must have shape (B, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len). + """ + self.block_mask = block_mask + self.world_size = world_size + + @staticmethod + def ptrr_scheduling(process_time: Tensor, group_size: int) -> Tensor: + """ + Separate the tasks into `group_size` groups using PTRR scheduling. + process_time: + 1D tensor of size n, where n is the number of tasks. The value + is the process time of the task. Size `n` must be divisible by + `group_size`. + group_size: + the number of groups + + Returns: + tasks_in_group (list[list[int]]): + A collection of list[int] and each list should have size `n // group_size` + (`group_size` lists in total). Each element is an index in the input + `process_time` (i.e. [0, len(process_time) - 1]). + + Example: + process_time = [9, 14, 2, 20, 10, 15, 8, 14, 16, 19, 15, 3, 12, 1, 12, 10] + tasks_in_group = [ + [3, 12, 13, 14], # values = [1, 12, 12, 20], sum = 45 + [2, 4, 7, 9], # values = [2, 10, 14, 19], sum = 45 + [1, 8, 11, 15], # values = [14, 16, 3, 10], sum = 43 + [0, 5, 6, 10] # values = [9, 15, 8, 15], sum = 47 + ] + """ + assert process_time.ndim == 1 + + num_tasks = process_time.size(0) + + if num_tasks % group_size != 0: + raise NotImplementedError( + f"num_tasks {num_tasks} must be divisible by group_size {group_size}" + ) + + device = process_time.device + _, sorted_indices_descending = torch.sort( + process_time, descending=True, stable=True + ) # if process time is tied, the order is preserved + sorted_indices_descending_reversed = torch.flip( + sorted_indices_descending.view(-1, group_size), dims=[1] + ).view(-1) + tasks_in_group = torch.where( + torch.arange(num_tasks, device=device) // group_size % 2 == 0, + sorted_indices_descending, + sorted_indices_descending_reversed, + ) + tasks_in_group = tasks_in_group.view(-1, group_size).transpose( + 0, 1 + ) # (group_size, n // group_size) + + # sort each group. This step should not have impact on correctness + # nor execution run time, but it helps users visualize the mask + tasks_in_group, _ = torch.sort(tasks_in_group, dim=1) + return tasks_in_group + + def _generate_indices(self, restore: bool = False) -> Tensor: + """ + Generate the PTRR reorder indices of shape `(1, seq_len)` or `(batch_size, seq_len)`. + + Args: + restore: + If True, generate restore indices that map Processing-Time based Round-Robin + (PTRR) rearranged positions back to original positions. If False, generate + load balance indices that rearrange original positions to PTRR pattern. + + Returns: + The generated indices of shape `(1, seq_len)` if the load-balancing is + identical within the batch (i.e. `BlockMask.shape[0] == 1`), or + `(batch_size, seq_len)` if the load-balancing should vary within the batch. + + Warning: + For Multi-Head Attention, we require the masks over the head dimension are identical + (i.e. `self.block_mask` must have shape (B, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)). + + Example: + Here is the document causal mask for attention whereq_len == kv_len == 16 * BLOCK_SIZE + (each entry is a block): + KV_index + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 1 + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 2 + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 3 + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 4 + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 1 + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 2 + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 3 + Q_index [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 4 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] -> row value = 5 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> row value = 6 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] -> row value = 7 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] -> row value = 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] -> row value = 1 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] -> row value = 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0] -> row value = 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1] -> row value = 4 + + The reorder indices will be: [2, 3, 5, 6, 8, 11, 12, 13, 0, 1, 4, 7, 9, 10, 14, 15] and + the mask matrix will look like: + KV_index + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 3 + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 4 + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 2 + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 3 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] -> row value = 5 rank 0 (sum=28) + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] -> row value = 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] -> row value = 1 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] -> row value = 2 + ------------------------------------------------ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 1 + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 2 + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 1 + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0] -> row value = 4 + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> row value = 6 rank 1 (sum=28) + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] -> row value = 7 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0] -> row value = 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1] -> row value = 4 + """ + block_mask = self.block_mask + kv_num_blocks = block_mask.kv_num_blocks + full_kv_num_blocks = block_mask.full_kv_num_blocks + non_sparse_kv_num_blocks = ( + kv_num_blocks + full_kv_num_blocks + if full_kv_num_blocks is not None + else kv_num_blocks + ) + B, H, Q = non_sparse_kv_num_blocks.shape + # requirement: the masking is identical across heads (i.e. H == 1 in BlockMask) + non_sparse_kv_num_blocks = non_sparse_kv_num_blocks.view(-1, Q) # (B, Q_BLK) + + batch_ptrr = torch.vmap( + functools.partial( + _PTRRLoadBalancer.ptrr_scheduling, + group_size=self.world_size, + ) + ) + ptrr_indices = batch_ptrr( + non_sparse_kv_num_blocks + ) # (B, group_size, num_blks_in_group) + ptrr_indices = ptrr_indices.reshape(B, -1) # (B, num_blocks) + + # NOTE: only support the case where the qkv block size are equal + q_blk_size, kv_blk_size = block_mask.BLOCK_SIZE + assert q_blk_size == kv_blk_size, ( + "for now only support q_blk_size == kv_blk_size" + ) + + indices = torch.arange( + q_blk_size * ptrr_indices.size(1), device=ptrr_indices.device + ).view(-1, q_blk_size) # (NUM_BLOCKS, BLOCK_SIZE) + indices = indices[ptrr_indices].view(B, -1) # (B, qkv_size) + + if restore: + indices = torch.vmap(torch.argsort)(indices) + + return indices + + +def _create_default_load_balancer( + seq_length: int, world_size: int, device: str | torch.device +) -> _LoadBalancer | None: + from ._attention import _cp_options + + if _cp_options.enable_load_balance: + return _HeadTailLoadBalancer(seq_length, world_size, device) + else: + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb6eb0cface8449af3b8d7df72c2e41a51de309 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py @@ -0,0 +1,406 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +""" +Context Parallelism sharding rules for scaled_dot_product attention operators. + +The sharding rules for CP cannot be embedded by default because Shard(2) is not +a valid sharding for SDPA without CP enabled. This module provides utilities to +dynamically install Shard(2) sharding rules when CP is activated. +""" + +from contextlib import contextmanager + +import torch +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, +) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, +) +from torch.distributed.tensor.placement_types import Replicate, Shard + + +aten = torch.ops.aten + +SEQ_DIM = 2 + + +@contextmanager +def _op_strategy_context(op_overload, strategy_func, schema_info=None): + """ + Context manager for setting and clearing op strategies for Context Parallelism. + + Args: + op_overload: The operator overload to set or clear the strategy for. + strategy_func: The strategy function to set for the operator overload. + schema_info: Optional schema information for the operator overload. + + Yields: + None + """ + from torch.distributed.tensor import DTensor + + propagator = DTensor._op_dispatcher.sharding_propagator + _origin_op_strategy_funcs = None + _origin_op_strategy_schema = None + try: + # Save original strategy if exists + if op_overload in propagator.op_strategy_funcs: + _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload] + if op_overload in propagator.op_to_schema_info: + _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload] + + # Register the new op strategy + register_op_strategy(op_overload, schema_info=schema_info)(strategy_func) + yield (_origin_op_strategy_funcs, _origin_op_strategy_schema) + finally: + # Restore original strategy + if _origin_op_strategy_funcs is None: + if op_overload in propagator.op_strategy_funcs: + del propagator.op_strategy_funcs[op_overload] + else: + propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs + + if _origin_op_strategy_schema is None: + if op_overload in propagator.op_to_schema_info: + del propagator.op_to_schema_info[op_overload] + else: + propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema + + # Ideally, we should clear the cache, but it is too expensive. + # _clear_python_sharding_prop_cache() + # _clear_fast_path_sharding_prop_cache() + + +# ==================== Flash Attention Strategies ==================== + + +def _scaled_dot_product_flash_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Strategy for flash attention forward with Context Parallelism support. + This includes the base strategies plus CP-specific sequence dimension sharding. + """ + # Import here to avoid circular dependency + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_flash_attention_base_strategies, + ) + + # Get the base strategies (without CP modifications) + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_flash_attention_base_strategies( + op_schema + ) + + # Add Context Parallelism strategy: shards on the sequence dim + return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else Replicate() + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + debug_attn_mask_sharding, # debugattn + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_flash_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for flash attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_flash_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_flash_attention_backward_base_strategies(op_schema) + ) + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + num_tensor_inputs = len(tensor_input_indices) + + # Context Parallelism: shards on the sequence dim + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # grad_q + Shard(SEQ_DIM), # grad_k + Shard(SEQ_DIM), # grad_v + Shard(SEQ_DIM), # grad_output + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + ] + cp_strategy.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +# ==================== Efficient Attention Strategies ==================== + + +def _scaled_dot_product_efficient_attention_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for efficient attention forward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_efficient_attention_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_base_strategies(op_schema) + ) + + # Add Context Parallelism strategy + has_attn_bias = op_schema.args_schema[3] is not None + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + None, # philox_seed + None, # philox_offset + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + if has_attn_bias: + cp_strategy.append(Replicate()) # attn bias - not sharded for CP + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=4 + ) + + +def _scaled_dot_product_efficient_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for efficient attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_efficient_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_efficient_attention_backward_base_strategies(op_schema) + ) + + has_attn_bias = op_schema.args_schema[4] is not None + + # Context Parallelism: shards on the sequence dim + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # grad_q + Shard(SEQ_DIM), # grad_k + Shard(SEQ_DIM), # grad_v + Shard(1) if has_attn_bias else None, # grad_bias + Shard(SEQ_DIM), # grad_output + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + Shard(SEQ_DIM), # output + Shard(SEQ_DIM), # logsumexp + ] + if has_attn_bias: + cp_strategy.insert(8, Shard(1)) # attn_bias input + cp_strategy.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=4 + ) + + +# ==================== cuDNN Attention Strategies ==================== + + +def _scaled_dot_product_cudnn_attention_cp_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Strategy for cudnn attention forward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_cudnn_attention_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = _scaled_dot_product_cudnn_attention_base_strategies( + op_schema + ) + + ( + query_strategy, + _, + _, + attn_bias_strategy, + compute_log_sumexp, + *rest_args, + ) = op_schema.args_schema + return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2] + has_attn_bias = attn_bias_strategy is not None + + # Context Parallelism: shards on the sequence dim + logsumexp_sharding = Shard(SEQ_DIM) if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = Shard(SEQ_DIM) if return_debug_mask else None + + cp_strategy: PlacementList = [ + Shard(SEQ_DIM), # output + logsumexp_sharding, # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, # debug_attn_mask + Shard(SEQ_DIM), # q + Shard(SEQ_DIM), # k + Shard(SEQ_DIM), # v + ] + if has_attn_bias: + cp_strategy.append(Replicate()) # attn_bias - not sharded for CP + single_mesh_dim_strategies.append(cp_strategy) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +def _scaled_dot_product_cudnn_attention_backward_cp_strategy( + op_schema: OpSchema, +) -> OpStrategy: + """ + Strategy for cudnn attention backward with Context Parallelism support. + """ + from torch.distributed.tensor._ops._matrix_ops import ( + _scaled_dot_product_cudnn_attention_backward_base_strategies, + ) + + mesh = op_schema.get_mesh_from_args(validate=False) + single_mesh_dim_strategies = ( + _scaled_dot_product_cudnn_attention_backward_base_strategies(op_schema) + ) + + has_attn_bias = op_schema.args_schema[8] is not None + has_scale = len(op_schema.args_schema) >= 16 and False + + # Context Parallelism: shards on the sequence dim + cp_sharding_gout: PlacementList = [Shard(SEQ_DIM)] * 3 # grad_q, grad_k, grad_v + cp_sharding_ginp: PlacementList = [ + Shard(SEQ_DIM) + ] * 6 # grad_output, q, k, v, output, logsumexp + cp_sharding_ginp += [Replicate()] * 2 # philox_seed, philox_offset + cp_sharding_ginp += [Shard(SEQ_DIM) if has_attn_bias else None] # attn_bias + cp_sharding_ginp += [ + None + ] * 6 # cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal + if has_scale: + cp_sharding_ginp.append(None) + + cp_sharding = cp_sharding_gout + cp_sharding_ginp + single_mesh_dim_strategies.append(cp_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +# Store context managers and original strategies +_cp_strategy_contexts = {} +_original_strategies = {} + + +def register_cp_sharding_rules(): + """Register Context Parallelism sharding rules for all scaled_dot_product ops.""" + global _cp_strategy_contexts, _original_strategies + + # If already registered, don't register again + if _cp_strategy_contexts: + return + + # Define ops and their corresponding CP strategy functions + cp_strategies = [ + ( + aten._scaled_dot_product_flash_attention.default, + _scaled_dot_product_flash_attention_cp_strategy, + RuntimeSchemaInfo(5), + ), + ( + aten._scaled_dot_product_flash_attention_backward.default, + _scaled_dot_product_flash_attention_backward_cp_strategy, + None, + ), + ( + aten._scaled_dot_product_efficient_attention.default, + _scaled_dot_product_efficient_attention_cp_strategy, + RuntimeSchemaInfo(4), + ), + ( + aten._scaled_dot_product_efficient_attention_backward.default, + _scaled_dot_product_efficient_attention_backward_cp_strategy, + None, + ), + ( + aten._scaled_dot_product_cudnn_attention.default, + _scaled_dot_product_cudnn_attention_cp_strategy, + RuntimeSchemaInfo(4), + ), + ( + aten._scaled_dot_product_cudnn_attention_backward.default, + _scaled_dot_product_cudnn_attention_backward_cp_strategy, + None, + ), + ] + + # Register each strategy + for op_overload, strategy_func, schema_info in cp_strategies: + ctx = _op_strategy_context(op_overload, strategy_func, schema_info) + orig_funcs, orig_schema = ctx.__enter__() + _cp_strategy_contexts[op_overload] = ctx + _original_strategies[op_overload] = (orig_funcs, orig_schema) + + +def unregister_cp_sharding_rules(clear_the_cache=False): + """Unregister Context Parallelism sharding rules and restore original strategies.""" + global _cp_strategy_contexts, _original_strategies + + # Exit all context managers + for ctx in _cp_strategy_contexts.values(): + ctx.__exit__(None, None, None) + + if clear_the_cache: + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() + + _cp_strategy_contexts = {} + _original_strategies = {} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4881de43874ab238b1cfbe6003c9a8751f0c3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from torch.distributed.tensor.parallel.api import parallelize_module +from torch.distributed.tensor.parallel.loss import loss_parallel +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + ParallelStyle, + PrepareModuleInput, + PrepareModuleInputOutput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, +) + + +__all__ = [ + "ColwiseParallel", + "ParallelStyle", + "PrepareModuleInput", + "PrepareModuleInputOutput", + "PrepareModuleOutput", + "RowwiseParallel", + "SequenceParallel", + "parallelize_module", + "loss_parallel", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f417e473e11cf7c76aa2cb2ba4952a55ffe45c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d16987c2d1a79b350dbcdda96ddcec8f11e78e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f21566022927336fd105bb2e8af1c0cbeb7fff0c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61a0f51a72666a8aba5a2902593eb6b8f1084041 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f75c8c07bb79817c11a6126c57d01c3fc7f3401 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29bf4494aed6fcb190c36b21844fd534dc27009a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b7ebefd675d112f22ffaec490e4f6315b075a2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d3f971a5dcfaafce42599656b76ab31fea50da1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..735b74e099478ebc606d68b3099f721c59874297 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -0,0 +1,51 @@ +from functools import partial +from typing import no_type_check + +import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec + + +@no_type_check +def sync_grad_hook(grad, *, device_handle=None, compute_stream=None): + if isinstance(grad, AsyncCollectiveTensor): + if compute_stream is not None: + with device_handle.stream(compute_stream): + grad = grad.wait() + else: + grad = grad.wait() + + return grad + + +def _flatten_tensor( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, DTensorSpec | None]: + if isinstance(tensor, DTensor): + tensor._local_tensor.requires_grad_() + return tensor._local_tensor, tensor._spec + return tensor, None + + +@no_type_check +def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None): + # unflatten would mainly be called every time FSDP allgather parameters. + result = DTensor.from_local( + tensor, + spec.mesh, + spec.placements, + run_check=False, + shape=spec.shape, + stride=spec.stride, + ) + if tensor.requires_grad: + # only register the hook if the tensor requires grad + tensor.register_hook( + partial( + sync_grad_hook, + device_handle=device_handle, + compute_stream=compute_stream, + ) + ) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/api.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/api.py new file mode 100644 index 0000000000000000000000000000000000000000..954b62327808d13da7d56923efe35600085eee1e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/api.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import warnings +from fnmatch import fnmatch + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor.parallel.style import ParallelStyle + + +__all__ = ["parallelize_module"] + + +def parallelize_module( # type: ignore[return] + module: nn.Module, + device_mesh: DeviceMesh | None = None, + parallelize_plan: ParallelStyle | dict[str, ParallelStyle] | None = None, + *, + src_data_rank: int | None = 0, +) -> nn.Module: + """ + Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. + + We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains + :class:`ParallelStyle`, which indicates how user wants the module or sub_module + to be parallelized. + + User can also specify different parallel style per module fully qualified name (FQN). + + Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`, + slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``) + + Args: + module (:class:`nn.Module`): + Module to be parallelized. + device_mesh (:class:`DeviceMesh`, optional): + Object which describes the mesh topology of devices for the DTensor. + If not specified, the call must be under a DeviceMesh context. + parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional): + The plan used to parallelize the module. It can be either a + :class:`ParallelStyle` object which contains how we prepare + input/output for Tensor Parallelism or it can be a dict of module + FQN and its corresponding :class:`ParallelStyle` object. If not + specified, the call will do nothing at the moment. + Keyword args: + src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is used by + :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. By default, + we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve the single-device + semantic. If passing ``None`` explicitly, :meth:`parallelize_module` simply uses its local data instead + of trying to preserve the single-device semantic via scatter/broadcast. Default: 0 + Return: + A :class:`nn.Module` object parallelized. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> # Define the module. + >>> m = Model(...) + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) + >>> + + .. note:: For complex module architecture like Attention, MLP layers, we recommend composing + different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and pass + as a parallelize_plan, to achieves the desired sharding computation. + """ + torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + + if parallelize_plan is None: + warnings.warn( + "No parallelize_plan is provided and auto-parallel is not supported " + "at the moment, so this parallelize_module call will do nothing.", + stacklevel=2, + ) + return module + + # note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't + # been initialized. + + if isinstance(parallelize_plan, ParallelStyle): + parallelize_plan.src_data_rank = src_data_rank + return parallelize_plan._apply(module, device_mesh) + elif isinstance(parallelize_plan, dict): + for module_path, parallelize_style in parallelize_plan.items(): + if module_path == "": + # shortcut: empty string means to apply the plan to the current module + parallelize_module(module, device_mesh, parallelize_style) + continue + + path_splits = module_path.split(".") + # Instead of blindly popping tokens, first check the match, + # we only consume/pop the token if we found a match. + token = path_splits[0] + + matched_children = list( + filter( + # `t[0]` is child name + lambda t: fnmatch(t[0], token), + module.named_children(), + ) + ) + if not matched_children: + # No match at this level. Log a warning and process next plan entry. + warnings.warn( + f"Parallelize plan key '{module_path}' could not be resolved: " + f"no submodule matching token '{token}' in module {module}, " + f"skipping this plan entry.", + stacklevel=2, + ) + continue + + # Now that we have a match, we can consume the token. + path_splits.pop(0) + # apply the plan to all matched submodules + for _, submodule in matched_children: + if path_splits: + # we haven't reached the leaf, apply in dict style + leaf_path = ".".join(path_splits) # rest of the path after `token` + parallelize_module( + submodule, + device_mesh, + {leaf_path: parallelize_style}, + src_data_rank=src_data_rank, + ) + else: + # otherwise, directly apply style to this submodule + parallelize_module( + submodule, + device_mesh, + parallelize_style, + src_data_rank=src_data_rank, + ) + return module + else: + raise TypeError( # pyre-ignore[7] + "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for" + f" parallelize_plan, {type(parallelize_plan)} found!" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/fsdp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..9e68ed6b1dba50f35981f3b633f089e852e57f7c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/fsdp.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +import copy +from typing import Any, cast + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +import torch.distributed.distributed_c10d as c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from torch.distributed.fsdp._common_utils import _set_fsdp_flattened +from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard +from torch.distributed.tensor.parallel._data_parallel_utils import ( + _flatten_tensor, + _unflatten_tensor, +) + + +__all__ = ["DTensorExtensions"] + + +def _get_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + placement = tensor.placements[0] + offsets = [0] * len(tensor.size()) + num_chunks = device_mesh.size(mesh_dim=0) + + if tensor.placements[0].is_shard(): + shard_dim = cast(DShard, placement).dim + chunk_size = tensor.size(shard_dim) // num_chunks + offsets[shard_dim] = chunk_size + + return (torch.Size(offsets), tensor._local_tensor.size()) + + +def _get_box_for(tensor: DTensor, idx: int) -> tuple[torch.Size, torch.Size]: + offsets, size = _get_box(tensor) + return (torch.Size([val * idx for val in offsets]), size) + + +def _get_local_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + coord = device_mesh.get_coordinate() + assert coord is not None + return _get_box_for(tensor, coord[0]) + + +def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + offsets, sizes = _get_local_box(dt) + return ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=f"rank:{current_rank}/{dt._local_tensor.device}", + ) + + +def _create_sharded_tensor_md_from_dt( + dt: DTensor, dt_pg: c10d.ProcessGroup +) -> ShardedTensorMetadata: + # This is where it gets tricky, we have to produce a ShardedTensor that has full coverage + # and yet has only one valid shard for the current rank. + + shards_md = [] + my_rank = dist.get_rank(dt_pg) + scapegoat_rank = 0 if my_rank > 0 else 1 + + if dt.placements[0].is_shard(): + shard_count = dt_pg.size() + else: + shard_count = 1 + + for i in range(shard_count): + offsets, sizes = _get_box_for(dt, i) + shards_md.append( + ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=( + f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}" + ), + ) + ) + + return ShardedTensorMetadata( + shards_metadata=shards_md, + size=dt.size(), + tensor_properties=TensorProperties( + dtype=dt.dtype, + layout=dt.layout, + requires_grad=dt.requires_grad, + # ignore memory_format and pin_memory as those are not supported by DT + ), + ) + + +def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + return mesh.get_group() + + +def _rewrite_spec_if_needed( + spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int +) -> shard_spec.ShardingSpec: + """ + Rewrite ``spec`` to match the device of ``tensor``. + + FSDP.sharded_optim_state_dict sneakly ships optimizer state to CPU so if the original ShardingSpec + produces CUDA metadata, ST construction bombs. + """ + if not isinstance(spec, ChunkShardingSpec): + return spec + + # let's see if we need + rewrite = False + for p in spec.placements: + p = cast(_remote_device, p) + if p.rank() == rank and p.device() != tensor.device: + rewrite = True + break + if rewrite: + spec = copy.deepcopy(spec) + # pyrefly: ignore [missing-attribute] + for i, placement in enumerate(spec.placements): + placement = cast(_remote_device, placement) + if placement.rank() == rank and placement.device() != tensor.device: + # pyrefly: ignore [missing-attribute] + spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}") + + return spec + + +def _chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, +) -> torch.Tensor: + if type(tensor) is ShardedTensor: + assert len(tensor.local_shards()) == 1 + + inner_param = tensor.local_tensor() + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + num_devices_per_node, + pg, + ) + + outer_local_shard = tensor.local_shards()[0] + shards: list[Shard] = [ + Shard(inner_st, copy.deepcopy(outer_local_shard.metadata)) + ] + st_meta = copy.deepcopy(tensor.metadata()) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=tensor._process_group, + init_rrefs=False, + ) + return st_outer + elif type(tensor) is DTensor: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + inner_param = tensor._local_tensor + + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + torch.accelerator.device_count(), + pg, + ) + + dt_pg = _get_dt_pg(tensor) + # We do this differently here, we create a ST with no local shards then patch it + shards = [ + Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg))) + ] + + st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=dt_pg, + init_rrefs=False, + ) + + return st_outer + else: + return _create_chunk_sharded_tensor( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, +) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. + + The local rank will gets its corresponding chunk as the local tensor to create a DTensor. + """ + root_mesh = device_mesh._get_root_mesh() if device_mesh is not None else None + if root_mesh is None: + raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.") + if root_mesh.ndim < 2: + raise RuntimeError( + f"Found parent device_mesh of ndim={root_mesh.ndim},", + "but meshes must be at least 2D.", + ) + + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.detach().clone() + + # When a layer is not involved in TP, then the tensor will not be a DTensor. + # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. + # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): + # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), Replicate()). + replicate_placements = [Replicate() for _ in range(root_mesh.ndim)] + shard_placements = [Replicate() for _ in range(root_mesh.ndim)] + shard_placements[0] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, root_mesh, replicate_placements, run_check=False + ).redistribute( + device_mesh=root_mesh, + placements=shard_placements, + ) + + else: + tp_placements = tensor.placements + tp_placement = tp_placements[0] + + tensor = tensor.to_local() + + # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), tp_placement). + # For higher dimensional meshes, it is replicated across other dimensions. For example, with + # HSDP the shard placements for tensor is (Replicate, Shard(0), tp_placement). + replicate_placements = [Replicate() for _ in range(root_mesh.ndim)] + replicate_placements[-1] = tp_placement # type: ignore[call-overload] + shard_placements = [Replicate() for i in range(root_mesh.ndim)] # type: ignore[misc] + shard_placements[-2] = DShard(0) # type: ignore[call-overload] + shard_placements[-1] = tp_placement # type: ignore[call-overload] + + return DTensor.from_local( + tensor, root_mesh, replicate_placements, run_check=False + ).redistribute( + device_mesh=root_mesh, + placements=shard_placements, + ) + + +def _pre_load_state_dict( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, list[Shard]]: + shards = cast(ShardedTensor, tensor).local_shards() + if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor: + inner_tensor = shards[0].tensor + shards = inner_tensor.local_shards() # pyre-ignore[16] + tensor = inner_tensor + + return (tensor, shards if len(shards) > 0 else []) + + +def _all_gather_dtensor( + tensor: DTensor, + parent_mesh: DeviceMesh | None, +) -> torch.Tensor: + """All gather a DTensor in its FSDP dimension and return the local tensor.""" + assert parent_mesh == tensor.device_mesh + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement] + # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] + for i in range(len(placements) - 1): + placements[i] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + + return tensor.to_local() + + +class DTensorExtensions(FSDPExtensions): + """ + DTensorExtension is the TensorFlattener extension needed for 2D FSDP + TP. + + This is the implementation for FSDPExtensions defined in + https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py + """ + + def __init__(self, device_handle) -> None: + super().__init__() + self.compute_stream = None + self.device_handle = device_handle + # we have to use the dynamo disable this way to disable dynamo as the decorator way would + # trigger build failure with torch deploy... + self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] + self.post_unflatten_transform + ) + + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, Any | None]: + return _flatten_tensor(tensor) + + def post_unflatten_transform( + self, tensor: torch.Tensor, param_extension: Any + ) -> torch.Tensor: + stream = self.compute_stream or self.device_handle.current_stream() + with self.device_handle.stream(stream): + # runtime we put the unflattened tensor call on the compute stream since + # the unflattened tensor might contain computations in fwd/bwd where we + # need to sync properly. + # TODO: this is a short term fix and we should make the get_unflat_views + # directly happen in the compute stream. + result = _unflatten_tensor( + tensor, + param_extension, + device_handle=self.device_handle, + compute_stream=self.compute_stream, + ) + _set_fsdp_flattened(result) + return result + + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: torch.device | None = None, + ) -> torch.Tensor: + return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg) + + def chunk_dtensor( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + return _chunk_dtensor(tensor, rank, device_mesh) + + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, list[Shard]]: + return _pre_load_state_dict(tensor) + + def all_gather_dtensor( + self, + tensor: DTensor, + parent_mesh: DeviceMesh | None, + ) -> torch.Tensor: + return _all_gather_dtensor(tensor, parent_mesh) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/input_reshard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/input_reshard.py new file mode 100644 index 0000000000000000000000000000000000000000..81e25621e040abd767e033ee2efec56e466262b6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/input_reshard.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from functools import partial +from typing import Any + +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard + + +__all__ = [ + "input_reshard", +] + + +def input_reshard( + module: torch.nn.Module, + tp_device_mesh: DeviceMesh, + input_reshard_dim: int | None = None, +) -> torch.nn.Module: + """ + Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation. + + Register hooks to an nn.Module with input resharding so that we can shard + per the given `tp_device_mesh` and `input_reshard_dim` and restore the + input back when recomputing the activations in the backward. The reason + why we can do this is that for Tensor Parallel(TP), the input are same + across all TP ranks. + + Args: + module (:class:`nn.Module`): + Module to be registered with input resharding. + tp_device_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for Tensor Parallel. + input_reshard_dim (Optional[int]): + The dimension of where we perform the sharding + of input. If set None, there is no sharding of input. + Default: None + + Return: + A :class:`nn.Module` object registered with TP input resharding. + """ + if input_reshard_dim is None: + return module + + cx: torch.autograd.graph.saved_tensors_hooks | None = None + + def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None: + saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( + partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim), + partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim), + ) + saved_tensor_hooks.__enter__() + nonlocal cx + cx = saved_tensor_hooks # type: ignore[name-defined] + + def input_reshard_backward_hook( + _: torch.nn.Module, _i: tuple[Any, ...], _o: Any + ) -> Any: + nonlocal cx + cx.__exit__() # type: ignore[name-defined, union-attr] + + module.register_forward_pre_hook(input_reshard_forward_pre_hook) + module.register_forward_hook(input_reshard_backward_hook) + return module + + +def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 + """Hook function called after FWD to shard input.""" + if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): + return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local(x, device_mesh=mesh) + .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + .to_local() + ) + else: + return x + + +def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 + """Hook function called before activation recomputing in BWD to restore input.""" + if ( + isinstance(x, DTensor) + and len(x._spec.placements) == 1 + and x._spec.placements[0].is_shard() + ): + return x.redistribute(device_mesh=mesh, placements=[Replicate()]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local( + x, device_mesh=mesh, placements=[Shard(input_reshard_dim)] + ) + .redistribute(device_mesh=mesh, placements=[Replicate()]) + .to_local() + ) + else: + return x diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/loss.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9c1adbf2a672a5bbe2004f17b1652c29803c9b33 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/loss.py @@ -0,0 +1,505 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +from typing import cast + +import torch +import torch._prims_common as utils +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops._embedding_ops import MaskPartial +from torch.distributed.tensor._ops._math_ops import ( + _skip_dim, + Reduction, + replicate_reduction_dims, +) +from torch.distributed.tensor._ops.utils import normalize_dim +from torch.distributed.tensor.placement_types import Placement + + +aten = torch.ops.aten + + +__all__ = ["loss_parallel"] + + +@contextlib.contextmanager +def loss_parallel(): + """ + A context manager that enables loss parallelism, where efficient parallelized loss computation + can be performed when the input is sharded on the class dimension. Currently only the cross-entropy + loss is supported. + + Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or + :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters. + The corresponding ``backward()`` call, if any, also needs to happen under this context manager. + + Args: + input (:class:`DTensor`): + Input logits. Assumed to be sharded on the class dimension. + target (Union[:class:`torch.Tensor`, :class:`DTensor`]): + Must be ground truth class indices (class probabilities currently not supported). + Assumed to be replicated across the ``DeviceMesh``. + weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional): + If given, assumed to be replicated across the ``DeviceMesh``. + label_smoothing: + Currently not supported. + + Returns: + A replicated :class:`DTensor`. + + Example: + A sharded DTensor is manually created here to showcase the usage. + In practice, it is usually the output of a TP module. + + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import loss_parallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> device_mesh = init_device_mesh("cuda", (8,)) + >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) + >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) + >>> target = torch.randint(16, (4,), device="cuda") + >>> with loss_parallel(): + >>> loss = F.cross_entropy(dist_input, target, reduction="mean") + >>> loss.backward() + >>> ... + """ + _enable_custom_loss_ops() + + yield + + _disable_custom_loss_ops() + + +# Currently only needs to support one dimensional DeviceMesh; in general return +# the mesh_dim with placements[mesh_dim].is_shard(dim) +def _find_all_reduce_mesh_dim(placements: tuple[Placement, ...], dim: int) -> int: + if not len(placements) == 1: + raise ValueError( + "Currently loss_parallel() only supports input on one-dimensional DeviceMesh." + ) + if not placements[0].is_shard(dim): + raise ValueError( + f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}." + ) + return 0 + + +def _cast_to_dtensor( + tensor, placements: tuple[Placement, ...], mesh: DeviceMesh +) -> DTensor: + if isinstance(tensor, DTensor): + if tensor.placements == placements: + return tensor + else: + raise RuntimeError(f"Expected {placements} but got {tensor.placements}.") + elif isinstance(tensor, torch.Tensor): + return DTensor.from_local( + tensor, device_mesh=mesh, placements=placements, run_check=False + ) + else: + raise TypeError(f"Unsupported type {type(tensor)}") + + +def _propagate_tensor_meta( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> TensorMeta: + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + tensor_meta = DTensor._op_dispatcher.sharding_propagator.propagate_tensor_meta( + op_info.schema + ) + if isinstance(tensor_meta, TensorMeta): + return tensor_meta + elif isinstance(tensor_meta, tuple): + return tensor_meta[0] + else: + raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.") + + +# NOTE: The implementation follows torch._decomp.decomposition._log_softmax, +# with all_reduce manually inserted to perform distributed computation. +def _log_softmax(x, dim, half_to_float, mesh, mesh_dim): + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(dtype=computation_dtype, memory_format=torch.contiguous_format) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + x_max = funcol.all_reduce( + x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim) + ) + shifted = x - x_max + shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True) + shifted_sumexp = funcol.all_reduce( + shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim) + ) + shifted_logsumexp = torch.log(shifted_sumexp) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +def _log_softmax_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + dim = cast(int, args[1]) + half_to_float = cast(bool, args[2]) + + spec = x._spec + dim = normalize_dim(dim, x.dim()) + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim) + + output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs) + + res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) + + res_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) + + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + res, + res_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=res.requires_grad, + ) + + +# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the +# _log_softmax_backward_handler does not actually do any computation. +def _log_softmax_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + input_dtype = cast(torch.dtype, args[3]) + return grad_output.to(input_dtype) + + +# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward, +# with customized communication inserted to perform distributed computation. +def _nll_loss_forward( + x: Tensor, + target: Tensor, + weight: Tensor | None, + local_weight: Tensor | None, + reduction: int, + ignore_index: int, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> tuple[Tensor, Tensor]: + n_dims = x.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + def _weight_view(weight: Tensor) -> Tensor: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + return w + + if weight is not None: + w = _weight_view(weight) + assert local_weight is not None + local_w = _weight_view(local_weight) + x = x * local_w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + + # The following code block is a distributed version of + # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target_partial_ = partial_placement._partition_value( + safe_target_, mesh, mesh_dim + ) + result_partial = torch.gather(x, channel_dim, safe_target_partial_) + # an all_reduce happens here + result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim) + result = -result_reduced.squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = x.new_full((), 0.0) + return result, total_weight + + if weight is not None: + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + # pyrefly: ignore [unbound-name] + w = w.expand(new_shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(x) + + # NOTE: this is correct only on 1D DeviceMesh; o/w additional + # all-reduce on result and total_weight is needed + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +def _nll_loss_forward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + target = args[1] + weight = args[2] + reduction = cast(int, args[3]) + ignore_index = cast(int, args[4]) + + channel_dim = 1 if x.dim() >= 2 else 0 + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # Check user input: if target and weight are not DTensors, convert them to DTensors; + # if they are DTensors, check that they have the desired placements. + target_placements = _skip_dim( + replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim + ) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + local_weight = None + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + # For local computation, both (replicated) weight and (sharded) local_weight + # are needed in _nll_loss_forward(). local_weight is generated here using + # DTensor API, without incurring any communication. + sharded_placements = [ + Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim) + ] + local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor + assert local_weight.shape[0] == x._local_tensor.shape[channel_dim] + + if reduction == Reduction.NONE.value: + output_placements = target_placements + else: + output_placements = all_replicate_placements + + # tensor inputs to _propagate_tensor_meta need to be DTensors + # pyrefly: ignore [bad-assignment] + args = list(args) + # pyrefly: ignore [unsupported-operation] + args[1], args[2] = target, weight + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result, total_weight = _nll_loss_forward( + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + local_weight, + reduction, + ignore_index, + x.shape, + channel_dim, + spec.mesh, + mesh_dim, + ) + out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) + + return ( + # pyrefly: ignore [bad-argument-type] + DTensor( + # pyrefly: ignore [bad-argument-count] + result, + out_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=result.requires_grad, + ), + total_weight, + ) + + +# NOTE: The backward computation of cross_entropy goes through two steps: +# backward for nll_loss and then backward for log_softmax. In loss parallel, +# the two steps are fused into the following function (called by _nll_loss_backward_handler) +# to avoid communication when target contains class indices not class probabilities. +# Also note that the _log_softmax_backward_handler does not perform computation. +# The implementation resembles _nll_loss_backward and _log_softmax_backward_data +# from torch._decomp.decomposition. +def _nll_loss_and_log_softmax_backward( + grad_output: Tensor, + x: Tensor, + target: Tensor, + weight: Tensor | None, + reduction: int, + ignore_index: int, + total_weight: Tensor, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tensor: + channel_dim = 0 if x.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(x) + + # The following code block is a distributed version of + # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + partial_placement = MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target = safe_target.squeeze(channel_dim).flatten() + masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim) + # only update grad_input to -1 if not masked + assert partial_placement.mask_buffer.data is not None + grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0 + arange_1d = torch.arange( + masked_safe_target.shape[0], device=masked_safe_target.device + ) + # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default; + # the last case is for aten.nll_loss2d_backward.default. + if x.dim() == 1: + grad_input[masked_safe_target] = grad_update + elif x.dim() == 2: + grad_input[arange_1d, masked_safe_target] = grad_update + else: + grad_input_t = grad_input.transpose(channel_dim, -1) + intermidate_shape = grad_input_t.shape + grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim]) + grad_input_2d[arange_1d, masked_safe_target] = grad_update + grad_input = grad_input_2d.view(intermidate_shape).transpose(channel_dim, -1) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(x.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + # In order for fused computation to work, the following line is rewritten. + # grad_output = grad_output * weight + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = weight.expand(new_shape) + w_target = torch.gather(w, channel_dim, target) + grad_output = grad_output * w_target + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax, + # here we perform backward computation for log_softmax altogether to avoid the + # otherwise extra all_gather communication. + # return grad_input * grad_output + return (grad_input + torch.exp(x)) * grad_output + + +def _nll_loss_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + x = cast(DTensor, args[1]) + target = args[2] + weight = args[3] + reduction = cast(int, args[4]) + ignore_index = cast(int, args[5]) + total_weight = cast(Tensor, args[6]) + + channel_dim = 1 if x.dim() >= 2 else 0 + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # if target and weight are not DTensors, convert them to DTensors + target_placements = _skip_dim( + replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim + ) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + + # tensor inputs to _propagate_tensor_meta need to be DTensors + # pyrefly: ignore [bad-assignment] + args = list(args) + # pyrefly: ignore [unsupported-operation] + args[2], args[3] = target, weight + # pyrefly: ignore [unsupported-operation] + args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh) + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result = _nll_loss_and_log_softmax_backward( + grad_output._local_tensor, + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + reduction, + ignore_index, + total_weight, + x.shape, + channel_dim, + spec.mesh, + mesh_dim, + ) + # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim + out_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) + + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + result, + out_spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=result.requires_grad, + ) + + +customized_loss_ops = { + aten._log_softmax.default: _log_softmax_handler, + aten._log_softmax_backward_data.default: _log_softmax_backward_handler, + aten.nll_loss_forward.default: _nll_loss_forward_handler, + aten.nll_loss2d_forward.default: _nll_loss_forward_handler, + aten.nll_loss_backward.default: _nll_loss_backward_handler, + aten.nll_loss2d_backward.default: _nll_loss_backward_handler, +} + + +def _enable_custom_loss_ops(): + DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops) + + +def _disable_custom_loss_ops(): + for custom_op in customized_loss_ops: + DTensor._op_dispatcher._custom_op_handlers.pop(custom_op) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py new file mode 100644 index 0000000000000000000000000000000000000000..9eed832eabe8653c9e02ee0bb72b2e1256f76275 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py @@ -0,0 +1,810 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from abc import ABC, abstractmethod +from functools import partial +from typing import Any + +import torch +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.placement_types import Placement + + +__all__ = [ + "ParallelStyle", + "RowwiseParallel", + "SequenceParallel", + "ColwiseParallel", + "PrepareModuleInput", + "PrepareModuleInputOutput", + "PrepareModuleOutput", +] + + +class ParallelStyle(ABC): + """ + The parallel style contract defines how the module or submodule should be parallelized. + + It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum + flexibility for different kind of style implementations. + """ + + src_data_rank: int | None = 0 + + @abstractmethod + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ... + + +class ColwiseParallel(ParallelStyle): + """ + Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. + Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. + (i.e. MLP, Attention) + + Keyword Args: + input_layouts (Placement, optional): + The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to + become a DTensor. If not specified, we assume the input tensor to be replicated. + output_layouts (Placement, optional): + The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module + with the user desired layout. If not specified, the output tensor is sharded on the last dimension. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. + Returns: + A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor + >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) + >>> ... + + .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not + specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``), + keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. + """ + + def __init__( + self, + *, + input_layouts: Placement | None = None, + output_layouts: Placement | None = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) + # colwise linear runtime sharding (desired sharding): + # 1. requires replicate input + # 2. shard output on last dim + self.desired_input_layouts = (Replicate(),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # TODO: figure out dynamo support for instance method and switch this to instance method + + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + # transform the input layouts to the desired layouts of ColwiseParallel + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + def _partition_linear_fn(self, name, module, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(0) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + def _partition_embedding_fn(self, name, module, device_mesh): + # colwise shard embedding.weight is straight forward as Shard(1) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(1)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if isinstance(module, nn.Linear): + partition_fn = self._partition_linear_fn + elif isinstance(module, nn.Embedding): + partition_fn = self._partition_embedding_fn + else: + raise NotImplementedError( + "ColwiseParallel currently only support nn.Linear and nn.Embedding!" + ) + + return distribute_module( + module, + device_mesh, + partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class RowwiseParallel(ParallelStyle): + """ + Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. + Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. + (i.e. MLP, Attention) + + Keyword Args: + input_layouts (Placement, optional): + The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to + become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. + output_layouts (Placement, optional): + The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module + with the user desired layout. If not specified, the output tensor is replicated. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. + Returns: + A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim + >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), + >>> ... + """ + + def __init__( + self, + *, + input_layouts: Placement | None = None, + output_layouts: Placement | None = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + def _partition_linear_fn(self, name, module, device_mesh): + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + module.register_parameter( + "weight", + nn.Parameter( + distribute_tensor( + module.weight, + device_mesh, + [Shard(1)], + src_data_rank=self.src_data_rank, + ) + ), + ) + if getattr(module, "bias", None) is not None: + # The Linear module has bias + module.register_parameter( + "bias", + nn.Parameter( + distribute_tensor( + module.bias, + device_mesh, + [Replicate()], + src_data_rank=self.src_data_rank, + ) + ), + ) + + def _partition_embedding_fn(self, name, module, device_mesh): + # rowwise shard embedding.weight is Shard(0) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if isinstance(module, nn.Linear): + partition_fn = self._partition_linear_fn + # rowwise linear runtime sharding requires input tensor shard on last dim + self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),) + elif isinstance(module, nn.Embedding): + partition_fn = self._partition_embedding_fn + # rowwise embedding runtime sharding requires input tensor replicated + self.desired_input_layouts = (Replicate(),) + else: + raise NotImplementedError( + "RowwiseParallel currently only support nn.Linear and nn.Embedding!" + ) + + return distribute_module( + module, + device_mesh, + partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class SequenceParallel(ParallelStyle): + """ + SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with + input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the + `RMSNorm python implementation `__ + + This style implements the operation that is described in the paper + `Reducing Activation Recomputation in Large Transformer Models `__ + + If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded + on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input + passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would + redistribute the input to be sharded on the sequence dimension. + + The output of the ``nn.Module`` will be sharded on the sequence dimension. + + Keyword Args: + sequence_dim (int, optional): + The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to + become a DTensor that is sharded on the sequence dimension, default: 1. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False. + Returns: + A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim + >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), + >>> ... + + .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. + ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom + inits for the weights on those modules, you need to broadcast the weights before/after parallelizing + to ensure that they are replicated. + """ + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): + super().__init__() + self.sequence_sharding = (Shard(sequence_dim),) + self.use_local_output = use_local_output + + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): + for p_name, param in module.named_parameters(): + # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow + # us to simply just use from_local + replicated_param = torch.nn.Parameter( + DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) + ) + module.register_parameter(p_name, replicated_param) + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True + ) + return input_tensor + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local( + input_tensor, device_mesh, sequence_sharding, run_check=False + ) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) + + @staticmethod + def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._replicate_module_fn, + partial(self._prepare_input_fn, self.sequence_sharding), + partial(self._prepare_output_fn, self.use_local_output), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + if len(self.sequence_sharding) == 1: + tmpstr += f"sequence_dim={self.sequence_sharding[0].dim}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleInput(ParallelStyle): + """ + Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to + ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. + input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor + >>> # and then redistributed to Replicated DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...) + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Placement | tuple[Placement | None, ...] | None = None, + desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None, + input_kwarg_layouts: dict[str, Placement] | None = None, + desired_input_kwarg_layouts: dict[str, Placement] | None = None, + use_local_output: bool = False, + ): + self.input_layouts = ( + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts + ) + self.desired_input_layouts = ( + (desired_input_layouts,) + if isinstance(desired_input_layouts, Placement) + else desired_input_layouts + ) + self.use_local_output = use_local_output + if self.input_layouts is not None: + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) + assert len(self.input_layouts) == len(self.desired_input_layouts), ( + "input_layouts and desired_input_layouts should have same length!" + ) + self.with_kwargs = input_kwarg_layouts is not None + self.input_kwarg_layouts = input_kwarg_layouts or {} + self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} + if self.with_kwargs: + assert len(self.input_kwarg_layouts) == len( + self.desired_input_kwarg_layouts + ), ( + "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + ) + + def _prepare_input_arg( + self, + input: Any, + mesh: DeviceMesh, + input_layout: Placement | None, + desired_layout: Placement | None, + ): + if input_layout is not None: + if isinstance(input, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = input + else: + assert isinstance(input, torch.Tensor), ( + "expecting input to be a torch.Tensor!" + ) + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) + + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + + return dt_inp.to_local() if self.use_local_output else dt_inp + else: + return input + + def _prepare_input_fn(self, inputs, device_mesh): + if self.input_layouts is None: + return inputs + prepared_inputs = [] + if not isinstance(inputs, tuple): + inputs = (inputs,) + if len(inputs) != len(self.input_layouts): + raise ValueError("module inputs and input_layouts should have same length!") + + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) + + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): + prepared_inputs.append( + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) + ) + return tuple(prepared_inputs) + + def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): + prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) + prepared_kwarg_inputs = {} + for kwarg_key in kwarg_inputs: + kwarg_val = kwarg_inputs[kwarg_key] + input_layout = self.input_kwarg_layouts.get(kwarg_key) + desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) + + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( + kwarg_val, device_mesh, input_layout, desired_input_layout + ) + + return (prepared_arg_inputs, prepared_kwarg_inputs) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if self.with_kwargs: + module.register_forward_pre_hook( + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( + inputs, kwargs, device_mesh + ), + with_kwargs=True, + ) # type: ignore[misc] + else: + module.register_forward_pre_hook( + lambda _, inputs: self._prepare_input_fn(inputs, device_mesh) + ) # type: ignore[misc, call-arg] + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"desired_input_layouts={self.desired_input_layouts}, " + tmpstr += f"input_kwarg_layouts={self.input_kwarg_layouts}, " + tmpstr += f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleOutput(ParallelStyle): + """ + Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to + ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``. + + Keyword Args: + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor + >>> # and then redistributed to Sharded DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan = PrepareModuleOutput( + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0) + >>> ) + >>> ) + """ + + def __init__( + self, + *, + output_layouts: Placement | tuple[Placement | None, ...], + desired_output_layouts: Placement | tuple[Placement, ...], + use_local_output: bool = True, + ): + self.output_layouts = ( + (output_layouts,) + if isinstance(output_layouts, Placement) + else output_layouts + ) + self.desired_output_layouts = ( + (desired_output_layouts,) + if isinstance(desired_output_layouts, Placement) + else desired_output_layouts + ) + self.use_local_output = use_local_output + assert len(self.output_layouts) == len(self.desired_output_layouts), ( + "output_layouts and desired_output_layouts should have same length!" + ) + + def _prepare_out_fn(self, outputs, device_mesh): + prepared_outputs = [] + if not isinstance(outputs, tuple): + outputs = (outputs,) + if len(outputs) != len(self.output_layouts): + raise ValueError( + "module outputs and output_layouts should have same length!" + ) + + for out, out_layout, desired_out_layout in zip( + outputs, self.output_layouts, self.desired_output_layouts + ): + if out_layout is not None: + if isinstance(out, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert out.placements[0] == out_layout + dt_out = out + else: + dt_out = DTensor.from_local( + out, device_mesh, (out_layout,), run_check=False + ) + + if out_layout != desired_out_layout: + dt_out = dt_out.redistribute(placements=(desired_out_layout,)) + prepared_outputs.append( + dt_out.to_local() if self.use_local_output else dt_out + ) + else: + prepared_outputs.append(out) + if len(prepared_outputs) == 1: + return prepared_outputs[0] + else: + return tuple(prepared_outputs) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + module.register_forward_hook( + lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh) + ) # type: ignore[misc, call-arg] + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"desired_output_layouts={self.desired_output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleInputOutput(ParallelStyle): + """ + Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module + to DTensors at runtime according to ``input_layouts`` (and output_layouts, respectively), and perform layout redistribution + according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of + :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. + input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_input (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor + >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated + >>> # as Replicated DTensor and then redistributed to Sharded DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInputOutput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...), + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0), + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Placement | tuple[Placement | None, ...] | None = None, + desired_input_layouts: Placement | tuple[Placement | None, ...] | None = None, + input_kwarg_layouts: dict[str, Placement] | None = None, + desired_input_kwarg_layouts: dict[str, Placement] | None = None, + use_local_input: bool = False, + output_layouts: Placement | tuple[Placement | None, ...], + desired_output_layouts: Placement | tuple[Placement, ...], + use_local_output: bool = True, + ): + self.prepare_module_input = PrepareModuleInput( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_input, + ) + self.prepare_module_output = PrepareModuleOutput( + output_layouts=output_layouts, + desired_output_layouts=desired_output_layouts, + use_local_output=use_local_output, + ) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + self.prepare_module_input._apply(module, device_mesh) + self.prepare_module_output._apply(module, device_mesh) + + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, " + tmpstr += ( + f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, " + ) + tmpstr += ( + f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, " + ) + tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, " + tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, " + tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, " + tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}" + tmpstr += ")" + return tmpstr diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaf359bc2f9e2d273ae40a5a122eea376e07c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py @@ -0,0 +1,1114 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass, field +from typing import cast, Optional + +import torch +import torch._C +import torch.distributed._functional_collectives as funcol +from torch._C._distributed import Placement +from torch.distributed._local_tensor import maybe_run_for_local_tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import ( + fill_empty_tensor_to_shards, + mesh_broadcast, + mesh_scatter, + pad_tensor, + shard_dim_alltoall, + unpad_tensor, +) +from torch.distributed.tensor._ops._mask_buffer import MaskBuffer + + +__all__ = ["Placement", "Shard", "Replicate", "Partial", "MaskPartial"] + + +# Appease TestPublicBindings.test_correct_module_names +Placement.__module__ = "torch.distributed.tensor.placement_types" + + +class Shard(torch._C._distributed.Shard): + """ + The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension + ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension only holds a shard/piece of the global Tensor. The + ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the + last few shards on the DeviceMesh dimension might be empty when the tensor dimension + is not evenly divisible on the DeviceMesh dimension. The ``Shard`` placement can be + used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) + + Args: + dim (int): The tensor dimension that describes the DTensor is sharded over its + corresponding DeviceMesh dimension. + + .. warning:: sharding on a tensor dimension where the tensor dimension size is not + evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + """ + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + """ + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_chunks - len(tensor_list) + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + shard_list: list[torch.Tensor] = [] + pad_sizes: list[int] = [] + for shard in tensor_list: + if with_padding: + pad_size = Shard._get_shard_pad_size(full_chunk_size, shard, self.dim) + shard = pad_tensor(shard, self.dim, pad_size) + pad_sizes.append(pad_size) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + @maybe_run_for_local_tensor + def local_shard_size_and_offset( + curr_local_size: int, + num_chunks: int, + rank: int, + ) -> tuple[int, int]: + """ + Given the size of the current local tensor (which may already be sharded on some dimensions), + computes the new local shard size and offset given the desired number of chunks + (num_chunks is generally equal to the size of the current sharding dim). + + Note: new local shard offset is relative to the current sharded tensor, not the global tensor. + See `_utils.compute_local_shape_and_global_offset` for computing global offset. + + Returns (new local shard size, offset) + + """ + # Compute the chunk size inline with ``torch.chunk`` + if curr_local_size % num_chunks == 0: + full_chunk_size = curr_local_size // num_chunks + return full_chunk_size, full_chunk_size * rank + + # uneven sharding case + full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if curr_local_size < shard_starting_idx: + return 0, curr_local_size + else: + local_shard_size = ( + min(curr_local_size, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx + + def _local_shard_size_and_offset( + self, + curr_local_size: int, + num_chunks: int, + rank: int, + ) -> tuple[int, int | None]: + return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank) + + @staticmethod + @maybe_run_for_local_tensor + def _maybe_unpad_tensor_with_sizes( + dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous + ) -> torch.Tensor: + # Only unpad if the local_tensor was padded on the dimension. + if pad_sizes[mesh_dim_local_rank] > 0: + local_tensor = unpad_tensor( + local_tensor, dim, pad_sizes[mesh_dim_local_rank] + ) + if make_contiguous: + local_tensor = local_tensor.contiguous() + return local_tensor + + def _shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + """ + Shard and scatter a tensor on a mesh dimension (use coordinate 0 on the + mesh dimension as source of truth). + + Create the local tensor for this rank following the given Shard + placement. If src_data_rank is None, perform only local splitting. + Otherwise, additionally scatter data from src_data_rank. Unlike + ``_split_tensor``, which supports uneven sharding via padding, this + method requires the tensor dimension to be evenly divisible by the + number of chunks (mesh dimension size). + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + + if src_data_rank is None: + # src_data_rank specified as None explicitly means to skip the + # communications, simply split + scatter_list, _ = self._split_tensor( + tensor, num_chunks, with_padding=False, contiguous=True + ) + + return self._select_shard(scatter_list, mesh_dim_local_rank) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + it = iter(scatter_list) + first = next(it) + # Tensors in the scatter list are expected to have the same shape because + # split is requested with padding. + assert all(first.shape == v.shape for v in it) + + output = torch.empty_like(first) + + # perform scatter from the src_data_rank as data source when it is not None + mesh_scatter( + output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank + ) + + return Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, mesh_dim_local_rank, True + ) + + @classmethod + def _make_shard_tensor( + cls, + dim: int, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + shard_placement = cls(dim) + return shard_placement._shard_tensor(tensor, mesh, mesh_dim, src_data_rank) + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: str, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return tensor + + is_padded = tensor.size(self.dim) % num_chunks != 0 + pad_sizes = None + if is_padded: + scattered_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + tensor = torch.cat(scattered_list, dim=self.dim) + elif not tensor.is_contiguous(): + tensor = tensor.contiguous() + + output = funcol.reduce_scatter_tensor( + tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) + ) + + if is_padded: + assert pad_sizes is not None + output = Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, my_coordinate[mesh_dim], False + ) + return output + + @maybe_run_for_local_tensor + def _maybe_pad_tensor( + self, + local_tensor: torch.Tensor, + logical_dim_size: int, + num_chunks: int, + ) -> torch.Tensor: + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + pad_size = full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + return local_tensor + + @maybe_run_for_local_tensor + def _maybe_unpad_tensor( + self, + local_tensor: torch.Tensor, + logical_dim_size: int, + num_chunks: int, + ) -> torch.Tensor: + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] + local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size) + + return local_tensor + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + logical_dim_size = current_logical_shape[self.dim] + + local_tensor = self._maybe_pad_tensor( + local_tensor, logical_dim_size, num_chunks + ) + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + + result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks) + + return result + + @staticmethod + @maybe_run_for_local_tensor + def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor: + return shards[shard_index].clone() + + def _replicate_to_shard( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_index: int, + ) -> torch.Tensor: + """ + transform from replicated tensor to a sharded tensor on + the current rank, which would perform a local chunk + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + shards, _ = self._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + + return Shard._select_shard(shards, shard_index) + + @staticmethod + @maybe_run_for_local_tensor + def _get_shard_pad_size( + full_size: int, local_tensor: torch.Tensor, dim: int + ) -> int: + """ + Get the padding size of the local tensor on the shard dimension. + """ + return full_size - local_tensor.size(dim) + + @staticmethod + def _compute_padding_info( + current_logical_shape: list[int], + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + ) -> tuple[bool, int, int, bool, int, int]: + results = [] + for shard_dim in [old_shard_dim, new_shard_dim]: + dim_logical_size = current_logical_shape[shard_dim] + dim_padding = dim_logical_size % num_chunks != 0 + dim_full_chunk_size = (dim_logical_size + num_chunks - 1) // num_chunks + results.append((dim_padding, dim_logical_size, dim_full_chunk_size)) + + return results[0] + results[1] + + @staticmethod + @maybe_run_for_local_tensor + def _pad_for_new_shard_dim( + current_logical_shape: list[int], + local_tensor: torch.Tensor, + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + ) -> torch.Tensor: + ( + old_dim_padding, + _, + old_dim_full_chunk_size, + new_dim_padding, + _, + new_dim_full_chunk_size, + ) = Shard._compute_padding_info( + current_logical_shape, num_chunks, old_shard_dim, new_shard_dim + ) + + if old_dim_padding: + old_dim_pad_size = Shard._get_shard_pad_size( + old_dim_full_chunk_size, local_tensor, old_shard_dim + ) + local_tensor = pad_tensor(local_tensor, old_shard_dim, old_dim_pad_size) + if new_dim_padding: + new_dim_pad_size = Shard._get_shard_pad_size( + new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim + ) + local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + return local_tensor + + @staticmethod + @maybe_run_for_local_tensor + def _unpad_for_new_shard_dim( + current_logical_shape: list[int], + local_tensor: torch.Tensor, + num_chunks: int, + old_shard_dim: int, + new_shard_dim: int, + local_rank: int, + ) -> torch.Tensor: + ( + old_dim_padding, + _, + old_dim_full_chunk_size, + new_dim_padding, + new_dim_logical_size, + new_dim_full_chunk_size, + ) = Shard._compute_padding_info( + current_logical_shape, num_chunks, old_shard_dim, new_shard_dim + ) + + if old_dim_padding: + old_dim_unpad_size = ( + old_dim_full_chunk_size * num_chunks + - current_logical_shape[old_shard_dim] # type: ignore[possibly-undefined] + ) + local_tensor = unpad_tensor(local_tensor, old_shard_dim, old_dim_unpad_size) # type: ignore[possibly-undefined] + + if new_dim_padding: + local_shard_size_on_new_dim = Shard.local_shard_size_and_offset( + new_dim_logical_size, num_chunks, local_rank + )[0] + new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] + local_tensor = unpad_tensor(local_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] + + return local_tensor + + def _to_new_shard_dim( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + new_shard_dim: int, + ) -> torch.Tensor: + """ + transform from existing sharded tensor to a new sharded tensor on + that shard on a new dimension, which performs an alltoall + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return local_tensor + + num_chunks = mesh.size(mesh_dim=mesh_dim) + + local_tensor = Shard._pad_for_new_shard_dim( + current_logical_shape, local_tensor, num_chunks, self.dim, new_shard_dim + ) + + new_tensor = shard_dim_alltoall( + local_tensor, self.dim, new_shard_dim, mesh, mesh_dim + ) + + new_tensor = Shard._unpad_for_new_shard_dim( + current_logical_shape, + new_tensor, + num_chunks, + self.dim, + new_shard_dim, + my_coordinate[mesh_dim], + ) + + return new_tensor + + def __hash__(self) -> int: + return hash(self.dim) + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +class _StridedShard(torch._C._distributed.StridedShard): + """ + _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor + is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. + We call this right-to-left sharding which is the opposite of the default + left-to-right sharding. See the example below: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [Shard(0), Shard(0)] + + The default sharding behavior shards the tensor on "dp" mesh dimension first then + "tp" dimension. The sharding result will be: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 1 (row 2-3) + 2 | (1, 0) | 2 (row 4-5) + 3 | (1, 1) | 3 (row 6-7) + + While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on + "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the + result: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The consequence is, any attempt to redistribute this DTensor to a full replica will + produce a wrong result because the shard-to-replicate redistribution always happens + right-to-left, regardless it's left-to-right sharding or right-to-left. To address + this, we use _StridedShard placement to make this right-to-left sharding compatible + with our left-to-right convention on both tensor distribution and redistribution. + + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + And a left-to-right processing of `placements` will produce the same result, which is + different from using the `Shard` placement: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on "dp" dim is 2. + + TODO: we should remove _StridedShard placement once we can unify it with Shard + """ + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + @staticmethod + @maybe_run_for_local_tensor + def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor: + return shards[shard_index].clone() + + @classmethod + def _make_shard_tensor( + cls, + dim: int, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + split_factor: int = 1, + ) -> torch.Tensor: + strided_shard_placement = cls(dim=dim, split_factor=split_factor) + return strided_shard_placement._shard_tensor( + tensor, mesh, mesh_dim, src_data_rank + ) + + def _shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: Optional[int] = 0, + ) -> torch.Tensor: + """ + Shard and scatter a tensor on a mesh dimension (use coordinate 0 on the + mesh dimension as source of truth). + + Create the local tensor for this rank following the given StridedShard + placement. If src_data_rank is None, perform only local splitting. + Otherwise, additionally scatter data from src_data_rank. Unlike + ``_split_tensor``, which supports uneven sharding via padding, this + method requires the tensor dimension to be evenly divisible by the + number of chunks (mesh dimension size). + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + + if src_data_rank is None: + # src_data_rank specified as None explicitly means to skip the + # communications, simply split + scatter_list, _ = self._split_tensor( + tensor, num_chunks, with_padding=False, contiguous=True + ) + + return self._select_shard(scatter_list, mesh_dim_local_rank) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + it = iter(scatter_list) + first = next(it) + # Tensors in the scatter list are expected to have the same shape because + # split is requested with padding. + assert all(first.shape == v.shape for v in it) + + output = torch.empty_like(first) + + # perform scatter from the src_data_rank as data source when it is not None + mesh_scatter( + output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank + ) + + return Shard._maybe_unpad_tensor_with_sizes( + self.dim, output, pad_sizes, mesh_dim_local_rank, True + ) + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # Essentially _StridedShard express the right-to-left sharding in the + # reversed order. Here we perform first_split as the virtual "right" sharding, + # and then second_split as the virtual "left" sharding, and finally assemble + # results in the transposed left-first order. + + # First split: chunk into split_factor pieces + first_split = list(torch.chunk(tensor, self.split_factor, dim=self.dim)) + first_split = fill_empty_tensor_to_shards( + first_split, self.dim, self.split_factor - len(first_split) + ) + + # Second split: chunk each piece into num_chunks pieces + second_split = [] + for s in first_split: + chunks = list(torch.chunk(s, num_chunks, dim=self.dim)) + chunks = fill_empty_tensor_to_shards( + chunks, self.dim, num_chunks - len(chunks) + ) + second_split.append(chunks) + + shard_list: list[torch.Tensor] = [] + for i in range(num_chunks): + shard = torch.cat( + [second_split[j][i] for j in range(self.split_factor)], + dim=self.dim, + ) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + + # The amount of padding is determined by the local chunk with the largest size. + pad_sizes: list[int] = [] + max_chunk_size = max([shard.size(self.dim) for shard in shard_list]) + if with_padding: + pad_sizes = [max_chunk_size - shard.size(self.dim) for shard in shard_list] + + return shard_list, pad_sizes + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + replay the replicate-to-shard process to understand how to stitch shards back + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + logical_dim_size = current_logical_shape[self.dim] + + # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed + # so that we can reuse self._split_tensor which splits on self.dim + shape = [1] * self.dim + [logical_dim_size] + indices_tensor = torch.arange( + logical_dim_size, device=local_tensor.device + ).view(shape) + + sharded_indices, _ = self._split_tensor( + indices_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + # squeeze back to 1D indices tensor + sharded_indices = [shard.view(-1) for shard in sharded_indices] + + max_chunk_size = max([len(shard) for shard in sharded_indices]) + local_pad_size = max_chunk_size - local_tensor.size(self.dim) + local_tensor_padded = pad_tensor(local_tensor, self.dim, local_pad_size) + + if not local_tensor_padded.is_contiguous(): + local_tensor_padded = local_tensor_padded.contiguous() + + replicate_tensor_permuted_padded = funcol.all_gather_tensor( + local_tensor_padded, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(replicate_tensor_permuted_padded, funcol.AsyncCollectiveTensor): + replicate_tensor_permuted_padded = replicate_tensor_permuted_padded.wait() + + if replicate_tensor_permuted_padded.shape[self.dim] > logical_dim_size: + replicate_tensor_permuted = unpad_tensor( + replicate_tensor_permuted_padded, + self.dim, + replicate_tensor_permuted_padded.shape[self.dim] - logical_dim_size, + ) + else: + replicate_tensor_permuted = replicate_tensor_permuted_padded + + permutation = torch.cat(sharded_indices) + inv_permutation = torch.argsort(permutation) + replicate_tensor = torch.index_select( + replicate_tensor_permuted, self.dim, inv_permutation + ) + + return replicate_tensor.contiguous() + + @staticmethod + @maybe_run_for_local_tensor + def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: + return len(sharded_indices[rank]) + + # delete pyre-ignore once separating _StridedShard from Shard + def _local_shard_size_and_offset( # pyre-ignore[bad-override] + self, + curr_local_size: int, + num_chunks: int, + rank: int, + return_first_offset: bool = True, + ) -> tuple[int, int | list[int]]: + return _StridedShard.local_shard_size_and_offset( + self, curr_local_size, num_chunks, rank, return_first_offset + ) + + @staticmethod + @maybe_run_for_local_tensor + def local_shard_size_and_offset( # pyre-ignore[bad-override] + self, + curr_local_size: int, + num_chunks: int, + rank: int, + return_first_offset: bool = True, + ) -> tuple[int, list[int] | int]: + """ + Compute the local shard size and offset(s) for a _StridedShard placement. + + Unlike the regular Shard placement which produces contiguous offsets, _StridedShard + produces non-contiguous (strided) offsets due to the right-to-left sharding semantics. + This method computes the actual indices that belong to the local shard. + + Args: + self (_StridedShard): The _StridedShard placement instance. + curr_local_size (int): The current size of the tensor dimension to be sharded. + num_chunks (int): Number of chunks to split the dimension into (typically the mesh dimension size). + rank (int): The rank index to compute the shard for. + return_first_offset (bool): If True, return only the first offset as an int. If False, + return all offsets as a list. Defaults to True. + + Returns: + tuple: A tuple containing: + - local_shard_size (int): The number of elements in the local shard for this rank. + - offset (int | list[int]): If return_first_offset is True, returns the first offset + as an int. If False or if the shard size is 0, returns a list of all offsets + (which may be empty for empty shards). + """ + # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed + # so that we can reuse self._split_tensor which splits on self.dim + shape = [1] * self.dim + [curr_local_size] + indices_tensor = torch.arange( + curr_local_size, + ).view(shape) + + sharded_indices, _ = self._split_tensor( + indices_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + # squeeze back to 1D indices tensor + sharded_indices = [shard.view(-1) for shard in sharded_indices] + + local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) + if local_shard_size > 0: + offsets = sharded_indices[rank].tolist() + else: + offsets = [] + + if return_first_offset: + # Always return an int for consistency across ranks. + # For empty shards, return -1 as an invalid offset indicator. + offsets = offsets[0] if len(offsets) > 0 else -1 + + return local_shard_size, offsets + + +class Replicate(torch._C._distributed.Replicate): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + @classmethod + def _make_replicate_tensor( + cls, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + + if src_data_rank is not None: + # perform broadcast from the src_data_rank as data source when it is not None + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank) + return tensor + + def _replicate_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: int | None = 0, + ) -> torch.Tensor: + return Replicate._make_replicate_tensor(tensor, mesh, mesh_dim, src_data_rank) + + +class Partial(torch._C._distributed.Partial): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. + """ + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value, e.g. + # - _partition_value on a sum reduce op is just a division operation + # - _reduce_value on a sum reduce op would just be a sum(allreduce) operation + num_chunks = mesh.size(mesh_dim=mesh_dim) + if self.reduce_op == "sum": + return tensor / num_chunks + elif self.reduce_op in ("avg", "min", "max"): + return tensor + else: + raise ValueError( + f"Replicate to Partial({self.reduce_op}) conversion is not supported." + ) + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return f"P({self.reduce_op})" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial + + +@dataclass(frozen=True) +class MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + # required fields for computing the local offset and deriving the mask + offset_shape: torch.Size | None = None + offset_dim: int = 0 + + def __init__( + self, + reduce_op=None, + mask_buffer=None, + offset_shape=None, + offset_dim=0, + *args, + **kwargs, + ): + super().__init__(reduce_op) + if mask_buffer is None: + mask_buffer = MaskBuffer() + object.__setattr__(self, "mask_buffer", mask_buffer) + object.__setattr__(self, "offset_shape", offset_shape) + object.__setattr__(self, "offset_dim", offset_dim) + + @staticmethod + @maybe_run_for_local_tensor + def _mask_tensor( + tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int + ) -> tuple[torch.Tensor, torch.Tensor]: + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + return mask, masked_tensor + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + my_coordinate = mesh.get_coordinate() + assert my_coordinate is not None, "my_coordinate should not be None" + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + assert self.offset_shape is not None, ( + "offset_shape needs to be set for MaskPartial" + ) + local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset( + self.offset_shape[self.offset_dim], + num_chunks, + my_coordinate[mesh_dim], + ) + mask, masked_tensor = MaskPartial._mask_tensor( + tensor, local_offset_on_dim, local_shard_size + ) + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we need reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we need reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.offset_shape == other.offset_shape + and self.offset_dim == other.offset_dim + ) + + def __hash__(self) -> int: + return 1 + hash( + ( + self.reduce_op, + self.offset_shape, + self.offset_dim, + ) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return f"MaskP({self.reduce_op}, {self.offset_shape}, {self.offset_dim})" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f396aa745940def6750fb2e0e62fa38b7281c04b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3744683f791a61616979fbe1b359ae2477a1323 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_draft_export.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e015ab8794cf8c49ca32719a6d0217486757e69 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_leakage_detection_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd330dc2a4d6a7a400933b3d8a6faf2f24bd1cef Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27dfe2b1797e2c83b6c1a5db7a2df5cdef5a974 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2805accd9d2c407e0d446b58242b290e04b7bca6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_safeguard.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9fa8691b2a5e5c48312e530c1684d255ba3b363 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_swap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c590c12d8397850f4848168e070dd267a374ffff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_trace.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88c8202fb53fc3a5b79903e7234ee1c6afeb4444 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0e6622718ec3f21133c0047bc0670c48fc9dc91 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_unlift.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aee8f73a93f2f350926b6190697270ac7863506d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e39e9270f023ca93c79c2e41be900b2d72513922 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_obj.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8e5c529bb8bc4147ec03f1511e6605f2af9679e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7425076c57cfd1f3060afcfdffdcd01c548ab42f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42a47b6329828bd5444be09968f30c4db211779b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..422285e8d884817025674b7fa9fcd691f73c9bf1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f2a5700c28ff0a5c817d96a68a9d418779069b5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59cdc471748b08ea91e62587327ddc76b02c1b5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14399a7bfdadd7d7a35781892dd60e8809a6d5b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py @@ -0,0 +1,430 @@ +import copy +import dataclasses +import functools +import os +import types +import typing +import typing_extensions +import zipfile +from pathlib import Path + +import torch +from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file +from torch.export.exported_program import _decompose_exported_program + + +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +__all__ = [] # type: ignore[var-annotated] + + +def _copy_graph_module_and_signature( + ep: torch.export.ExportedProgram, +) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: + # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), + # and this can break placeholder names in some particular cases. + # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. + # So we manually overwrite placeholder names by reading the old graph. + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + + # iterate over old/new graph modules + for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] + old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] + new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] + # iterate over placeholders + assert len(old_phs) == len(new_phs) + for old_node, new_node in zip(old_phs, new_phs): + new_node.name = old_node.name + + return gm, new_graph_signature + + +def _remove_detach_pass( + gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature +) -> None: + with gm._set_replace_hook(sig.get_replace_hook()): + for node in list(reversed(gm.graph.nodes)): + if node.op != "call_function": + continue + if ( + node.target is torch.ops.aten.detach.default + and len(node.users) == 1 + and next(iter(node.users)).target is torch.ops.aten.detach.default + ): + next(iter(node.users)).replace_all_uses_with(node) + + gm.graph.eliminate_dead_code() + gm.recompile() + + +def _export_forward_backward( + ep: torch.export.ExportedProgram, joint_loss_index: int = 0 +) -> torch.export.ExportedProgram: + """ + WARNING: This API is highly unstable and will be subject to change in the future. + """ + from torch._decomp import core_aten_decompositions + + ep = _decompose_exported_program( + ep, + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), + joint_loss_index=joint_loss_index, + # For serialization purpose, we don't want to decompose custom triton ops. + # If users would like to decompose custom triton ops, they could do it + # with run_decompositions() API. + decompose_custom_triton_ops=False, + ) + gm, new_graph_signature = _copy_graph_module_and_signature(ep) + _remove_detach_pass(gm, new_graph_signature) + + return ep._update(gm, new_graph_signature) + + +def _sticky_export( + forward_func: typing.Callable[_InputT, _RetT], + dynamic_shapes_callback: typing.Callable[ + _InputT, list[typing.Any] | dict[str, typing.Any] | tuple[typing.Any, ...] + ] + | None = None, +) -> typing.Callable[_InputT, _RetT]: + """ + Lazily export the model on first forward call. + Usage: + model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) + """ + model = forward_func.__self__ # type: ignore[attr-defined] + original_forward = forward_func.__func__ # type: ignore[attr-defined] + + @functools.wraps(forward_func) + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: + # Unpatch forward to avoid recursion during export + model.forward = types.MethodType(original_forward, model) + + dynamic_shapes_spec = None + if dynamic_shapes_callback: + dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs) + + try: + exported = torch.export.export( + model, + args, + kwargs, + dynamic_shapes=dynamic_shapes_spec, + ).module() + wrapper._exported_artifact = exported # type: ignore[attr-defined] + finally: + # Restore the wrapper after export + model.forward = wrapper + + return exported(*args, **kwargs) + + return wrapper + + +@dataclasses.dataclass +class _ExportMethod: + overloads: dict[str, torch.export.ExportedProgram] + fallbacks: list[torch.export.ExportedProgram] + + +class _ExportPackage: + """ + An export package is a collection of torch.export()-ed PyTorch models consisting of + a list of exported methods and their corresponding overloads. ExportPackage is introduced + on top of torch.export() to support the following use cases: + - Exporting a model with multiple methods if a model has multiple independent parts. + - Exporting a function with multiple overloads based on tensor shapes or other metadata. + + ExportPackage is designed to contain multiple methods (associated with method names) and for + each method, it can have multiple overloads (associated with overload names). + + Here is an example of the data structure for an ExportPackage: + ``` + ExportPackage( + methods={ + "decoder": ExportMethod( + overloads={ + "prefill": ExportedProgram(...), + "decode": ExportedProgram(...), + }, + fallbacks=[], + ), + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), + }, + ) + ``` + + To export a model into an ExportPackage, users can use the exporter API provided by ExportPackage. + Exporter is a decorator that takes a callable and returns a wrapper. The wrapper will export the + function into an ExportPackage, when it's invoked with some sample inputs (similar to how + torch.compile() works). For more details, please refer to the document on .exporter() method. + + This design allows users to decouple the exported callables from the actual sample inputs which can + be helpful for use cases where the exported callable is hidden behind helper functions or when sample + inpusts are hard to get. + + NOTE: This is an experimental API and anything can be changed in the future. + + Example usage: + ``` + def fn(x): + return x + 1 + + def main(f, x): + x += 1 + ret = f(x) + return ret + 1 + + package = ExportPackage() + main(package.exporter(fn), torch.randn(3, 2)) + ``` + + """ + + def __init__(self) -> None: + self.methods: dict[str, _ExportMethod] = {} + + def _exporter( + self, + method: str, + fn: typing.Callable[_InputT, _RetT], + *, + fallback: str = "once", + ) -> typing.Callable[_InputT, _RetT]: + """ + A function/module decorator that sets up a callable to be exported later invoked. + By default the exporter will only trigger torch.export for once and error on + later invocations. To customize this behavior, users have the following two options: + 1. Call .define_overload() method on the returned wrapper to define an overload. + 2. Adjust the fallback policy using `fallback` argument. + + An "overload" is a named branch for an ExportMethod with a user defined precondition, + typically based on input tensor shapes. It's up to a downstream backend implementation + of ExportMethod to respect the precondition later in inference. + + define_overload() takes arguments like the following: + - A name, for indexing purposes in a backend. + - A callable (spec) that: + - Has the same model input signature as the original model code. + - Returns an optional dynamic shape spec. + + Exporter will only export an overload when the spec callable successfully returns + a result without raising AssertionError. + + For example: + ``` + package = ExportPackage() + + + def prefill(x, xa, kv_cache): + assert x.shape[1] == 3 + assert kv_cache == {} + + + def decode(x, xa, kv_cache): + assert x.shape[1] > 1 + assert len(kv_cache) > 0 + return {...} # dynamic shape specs here + + + exporter = ( + package.exporter(decoder) + .define_overload("prefill", prefill) + .define_overload("decode", decode) + ) + ``` + + A "fallback" is exported when no overload precondition matches a given set of sample + inputs. Overloads should + Fallbacks don't have names and are ordered in a list. It's up to a backend to decide + which fallback is used amony multiple ones. + + A reference backend implementation of ExportMethod may look like the following: + ``` + def execute(method: ExportMethod, *args, **kwargs): + for overload in method.overloads: + if match_precondition(overload, *args, **kwargs): + return execute_overload(overload, *args, **kwargs) + for fallback in method.fallbacks: + if match_precondition(fallback, *args, **kwargs): + return execute_fallback(fallback, *args, **kwargs) + ``` + + Args: + method(str): The method name for an exported part of PyTorch model. This + will be saved together with the exported/compiled artifacts + in any serialization format and can be used as the key to + index ExportPackage methods later. + fn(callable): A PyTorch function/module to be exported. + fallback(str): The fallback policy to decide when to call torch.export + - "once" is the default policy. Under this policy a PyTorch program is assumed + to be only called once later and an error will be raised for subsequent + runs. + - "error" means the ExportMethod will never have any fallbacks, meaning + users should define all the possible overloads ahead of time. + + """ + + fallbacks: list[torch.export.ExportedProgram] = [] + specs: dict[str, typing.Callable[_InputT, typing.Any]] = {} + overloads: dict[str, torch.export.ExportedProgram] = {} + self.methods[method] = _ExportMethod(fallbacks=fallbacks, overloads=overloads) + + @functools.wraps(fn) + def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] + import torch.export._wrapper_utils + + model: torch.nn.Module + if not isinstance(fn, torch.nn.Module): + model = torch.export._wrapper_utils._WrapperModule(fn) + else: + model = fn + + for k, v in specs.items(): + try: + if isinstance(fn, torch.nn.Module): + dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type] + else: + # pyrefly: ignore [invalid-param-spec] + dynamic_shapes = v(*args, **kwargs) + except AssertionError: + continue + if k not in overloads: + ep = torch.export.export( + model, args, kwargs, dynamic_shapes=dynamic_shapes + ) + overloads[k] = ep + ep = overloads[k] + return ep.module()(*args, **kwargs) + + if fallback == "error": + raise RuntimeError( + f"Exporter: Cannot export fallback {fn} when fallback policy is set to 'error'," + + "please specify an overload or adjust the fallback policy." + ) + elif fallback == "once": + if len(fallbacks) > 0: + raise RuntimeError( + f"Exporter: Cannot export {fn} more than once, " + + "please specify an overload or adjust the fallback policy." + ) + else: + raise RuntimeError(f"Unknown fallback policy: {fallback}") + ep = torch.export.export(model, args, kwargs) + + fallbacks.append(ep) + return ep.module()(*args, **kwargs) + + if isinstance(fn, torch.nn.Module): + _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 + fn, + lambda _: _exporter_context, # type: ignore[arg-type] + ) + + def _define_overload( + overload: str, spec: typing.Callable[_InputT, typing.Any] + ) -> typing.Any: + assert overload not in specs + assert callable(spec) + assert overload.isidentifier() + specs[overload] = spec + return _exporter_context + + assert not hasattr(fn, "_define_overload") + _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] + + # pyrefly: ignore [bad-return] + return _exporter_context + + @property + def _method_overloads( + self, + ) -> typing.Iterator[tuple[str, torch.export.ExportedProgram]]: + for method, method_data in self.methods.items(): + for overload, ep in method_data.overloads.items(): + yield f"{method}:{overload}", ep + + def _compiled_and_package( + self, + f: torch.types.FileLike, + standalone: bool = False, + package_example_inputs: bool = False, + ) -> None: + options: dict[str, typing.Any] = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": True, + "always_keep_tensor_constants": True, + # we'll change this back to False once we enable weight deduping for standalone mode + "aot_inductor.package_constants_in_so": standalone, + "aot_inductor_mode.compile_standalone": standalone, + } + aoti_files_map = {} + model_names = [] + for name, ep in self._method_overloads: + name = name.replace(":", "__") + model_names.append(name) + options["aot_inductor.model_name_for_generated_files"] = name + aoti_files = torch._inductor.aot_compile( + ep.module(), # type: ignore[arg-type] + ep.example_inputs[0], + kwargs=ep.example_inputs[1], + options=options, + ) + # pyrefly: ignore [unsupported-operation] + aoti_files_map[name] = aoti_files + + from torch._inductor.package import package + + pt2_path = package.package_aoti( + f, + aoti_files_map, # type: ignore[arg-type] + ) + + if not standalone: + return + + assert isinstance(pt2_path, str) + base_directory = os.path.dirname(pt2_path) + package_name = os.path.basename(pt2_path)[:-4] + with ( + zipfile.ZipFile(pt2_path, "r") as zip_ref, + ): + zip_ref.extractall(base_directory) + + example_inputs_map: dict[str, int] | None = ( + {} if package_example_inputs else None + ) + use_cuda = False + for name, ep in self._method_overloads: + name = name.replace(":", "__") + # TODO: also dump kwargs + # TODO: currently only support list of Tensors and they need to be on the same device + if not ep.example_inputs: + continue + for inp in ep.example_inputs[0]: + if isinstance(inp, torch.Tensor) and inp.device.type == "cuda": + # TODO: more carefully determine the device type + use_cuda = True + if package_example_inputs: + assert example_inputs_map is not None + example_inputs_map[name] = len(ep.example_inputs[0]) + for i, t in enumerate(ep.example_inputs[0]): + path = Path(base_directory) / f"{name}_input_{i}.pt" + torch.save(t, path) + + # Detect if ROCm is being used + is_hip = torch.version.hip is not None + cmake_file_str = _get_make_file(package_name, model_names, use_cuda, is_hip) + + with open(Path(base_directory) / "CMakeLists.txt", "w") as file: + file.write(cmake_file_str) + + main_file_str = _get_main_cpp_file( + package_name, model_names, use_cuda, example_inputs_map, is_hip + ) + with open(Path(base_directory) / "main.cpp", "w") as file: + file.write(main_file_str) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f7a6d2be0a04e60a84401b6d5a8c664b2d1a491 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53626f2920403e17cc0c04f4b4504c4cdb0d87b4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1005effe2f299a2bd33ac0517e24b46d840bf675 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/experimental/_utils.py @@ -0,0 +1,238 @@ +import logging + +from torch._inductor.utils import IndentedBuffer + + +__all__ = [] # type: ignore[var-annotated] +logger = logging.getLogger(__name__) + + +def _get_main_cpp_file( + package_name: str, + model_names: list[str], + cuda: bool, + example_inputs_map: dict[str, int] | None, + is_hip: bool, +) -> str: + """ + Generates a main.cpp file for AOTInductor standalone models in the specified package. + + Args: + package_name (str): Name of the package containing the models. + model_names (List[str]): List of model names to include in the generated main.cpp. + cuda (bool): Whether to generate code with CUDA support. + example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to + its list of example input tensors. If provided, the generated main.cpp will + load and run these inputs. + + Returns: + str: The contents of the generated main.cpp file as a string. + """ + + ib = IndentedBuffer() + + ib.writelines( + [ + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + ] + ) + if cuda: + if is_hip: + ib.writelines( + [ + "#include ", + ] + ) + + else: + ib.writelines( + [ + "#include ", + "#include ", + ] + ) + + for model_name in model_names: + ib.writeline( + f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"' + ) + + ib.newline() + for model_name in model_names: + ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};") + + ib.writelines( + [ + "using torch::aot_inductor::ConstantHandle;", + "using torch::aot_inductor::ConstantMap;", + "", + "int main(int argc, char* argv[]) {", + ] + ) + + with ib.indent(): + ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";') + ib.writeline("try {") + + with ib.indent(): + ib.writeline("c10::Device device(device_str);") + + if example_inputs_map is not None: + # TODO: add device + for i, model_name in enumerate(model_names): + num_inputs = example_inputs_map[model_name] + + ib.writeline(f"// Load input tensors for model {model_name}") + ib.writeline(f"std::vector input_tensors{i + 1};") + ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{") + with ib.indent(): + ib.writeline( + f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";' + ) + ib.writeline("std::ifstream in(filename, std::ios::binary);") + ib.writeline("if (!in.is_open()) {") + with ib.indent(): + ib.writeline( + 'std::cerr << "Failed to open file: " << filename << std::endl;' + ) + ib.writeline("return 1;") + ib.writeline("}") + ib.writeline( + "std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator());" + ) + ib.writeline( + "torch::IValue ivalue = torch::pickle_load(buffer);" + ) + ib.writeline( + f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));" + ) + ib.writeline("}") + ib.newline() + + ib.newline() + ib.writeline("\n// Create array of input handles") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto input_handles{i + 1} =", + f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});", + ] + ) + + ib.writeline("\n// Create array for output handles") + for i in range(len(model_names)): + ib.writeline(f"AtenTensorHandle output_handle{i + 1};") + + ib.writeline("\n// Create and load models") + for i, model_name in enumerate(model_names): + ib.writelines( + [ + f"auto constants_map{i + 1} = std::make_shared();", + f"auto constants_array{i + 1} = std::make_shared>();", + f"auto model{i + 1} = std::make_unique(", + f" std::move(constants_map{i + 1}),", + f" std::move(constants_array{i + 1}),", + " device_str,", + f' "{package_name}/data/aotinductor/{model_name}/");', + f"model{i + 1}->load_constants();", + ] + ) + + if example_inputs_map is not None: + ib.writeline("\n// Run the models") + for i in range(len(model_names)): + ib.writeline( + f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;" + ) + ib.writeline( + f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);" + ) + + ib.writeline("\n// Convert output handles to tensors") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto output_tensor{i + 1} =", + f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);", + ] + ) + + ib.writeline("\n// Validate outputs") + for i in range(len(model_names)): + ib.writeline( + f"""std::cout << "output_tensor{i + 1}\\n" << output_tensor{i + 1} << std::endl;""" + ) + ib.writeline( + f"""torch::save(output_tensor{i + 1}, "output_tensor{i + 1}.pt");""" + ) + + ib.writeline("return 0;") + + ib.writelines( + [ + "} catch (const std::exception &e) {", + ] + ) + with ib.indent(): + ib.writeline('std::cerr << "Error: " << e.what() << std::endl;') + ib.writeline("return 1;") + + ib.writeline("}") + ib.writeline("}") + + return ib.getvalue() + + +def _get_make_file( + package_name: str, model_names: list[str], cuda: bool, is_hip: bool +) -> str: + ib = IndentedBuffer() + + ib.writelines( + [ + "cmake_minimum_required(VERSION 3.10)", + "project(TestProject)", + "", + "set(CMAKE_CXX_STANDARD 17)", + "", + ] + ) + + from torch._inductor.config import test_configs + + if test_configs.use_libtorch: + ib.writeline("find_package(Torch REQUIRED)") + + if cuda: + if is_hip: + ib.writeline("find_package(hip REQUIRED)") + else: + ib.writeline("find_package(CUDA REQUIRED)") + + ib.newline() + for model_name in model_names: + ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)") + + ib.writeline("\nadd_executable(main main.cpp)") + if cuda: + if is_hip: + ib.writeline("target_compile_definitions(main PRIVATE USE_HIP)") + else: + ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") + + model_libs = " ".join(model_names) + ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + + if cuda: + if is_hip: + ib.writeline("target_link_libraries(main PRIVATE hip::host)") + else: + ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") + + return ib.getvalue() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5634bd4eadb7a80ddb7521ec0dae26fb2cfec5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py @@ -0,0 +1,97 @@ +from typing import Union + +import torch +import torch.utils._pytree as pytree +from torch.export.exported_program import ExportedProgram + + +__all__ = ["move_to_device_pass"] + + +def move_to_device_pass( + ep: ExportedProgram, location: torch.device | str | dict[str, str] +) -> ExportedProgram: + """ + Move the exported program to the given device. + + Args: + ep (ExportedProgram): The exported program to move. + location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. + If a string, it is interpreted as a device name. + If a dict, it is interpreted as a mapping from + the existing device to the intended one + + Returns: + ExportedProgram: The moved exported program. + """ + + def _get_new_device( + curr_device: torch.device, + location: torch.device | str | dict[str, str], + ) -> str: + if isinstance(location, dict): + if str(curr_device) in location: + return location[str(curr_device)] + else: + return str(curr_device) + else: + return str(location) + + # move all the state_dict + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter( + v.to(_get_new_device(v.device, location)), + v.requires_grad, + ) + else: + ep._state_dict[k] = v.to(_get_new_device(v.device, location)) + + # move all the constants + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = v.to(_get_new_device(v.device, location)) + + # move example_inputs if they exist + if ep.example_inputs is not None: + args, kwargs = ep.example_inputs + moved_args = pytree.tree_map_only( + torch.Tensor, + lambda tensor: tensor.to(_get_new_device(tensor.device, location)), + args, + ) + moved_kwargs = pytree.tree_map_only( + torch.Tensor, + lambda tensor: tensor.to(_get_new_device(tensor.device, location)), + kwargs, + ) + ep._example_inputs = (moved_args, moved_kwargs) + + for m in ep.graph_module.modules(): + if isinstance(m, torch.fx.GraphModule): + for node in m.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + + if ( + node.op == "call_function" + and node.target is torch.ops.aten.to.device + ): + args = list(node.args) + # pyrefly: ignore [unsupported-operation] + args[1] = _get_new_device(args[1], location) + node.args = tuple(args) + + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) + + ep.validate() + return ep diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc20cffecd65a69cc299a78c16a3c26123c27152 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2bf26a275d9eef91f4b6807ac472b2cd0c30b0f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py @@ -0,0 +1,4 @@ +from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter + + +__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea9bf62a180e7ca40781c38309bcb40fc4c937bb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eec452b95d5eb69f587ed4fecedaaf4a450126e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98054f136493339801f18e9188e6cdfd07324f8f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd8405c3f8eeb170346e93a843317cfbec46afc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py new file mode 100644 index 0000000000000000000000000000000000000000..1b46db0958d28b37a602686b34e400c17cecacb3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py @@ -0,0 +1,1204 @@ +import glob +import io +import json +import logging +import os +import tempfile +import zipfile +from dataclasses import dataclass +from typing import Any, IO, TYPE_CHECKING, TypeAlias + +import torch +import torch.utils._pytree as pytree +from torch._export.serde import schema +from torch._export.serde.serialize import ( + _dataclass_to_dict, + _dict_to_dataclass, + deserialize_device, + deserialize_scalar_type, + deserialize_size, + deserialize_storage_offset, + deserialize_stride, + ExportedProgramDeserializer, + serialize, + serialize_tensor_meta, + SerializedArtifact, +) +from torch._inductor.cpp_builder import normalize_path_separator +from torch._subclasses.fake_tensor import FakeTensor +from torch.export import ExportedProgram +from torch.export._tree_utils import reorder_kwargs +from torch.export.pt2_archive._package_weights import ( + get_complete, + group_weights, + TensorProperties, + Weights, +) +from torch.export.pt2_archive.constants import ( + AOTINDUCTOR_DIR, + ARCHIVE_FORMAT_PATH, + ARCHIVE_FORMAT_VALUE, + ARCHIVE_VERSION_PATH, + ARCHIVE_VERSION_VALUE, + CONSTANTS_CONFIG_FILENAME_FORMAT, + CONSTANTS_DIR, + CUSTOM_OBJ_FILENAME_PREFIX, + EXECUTORCH_DIR, + EXTRA_DIR, + MODELS_DIR, + MODELS_FILENAME_FORMAT, + SAMPLE_INPUTS_FILENAME_FORMAT, + TENSOR_CONSTANT_FILENAME_PREFIX, + WEIGHT_FILENAME_PREFIX, + WEIGHTS_CONFIG_FILENAME_FORMAT, + WEIGHTS_DIR, +) +from torch.types import FileLike + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + +DEFAULT_PICKLE_PROTOCOL = 2 +AOTI_FILES: TypeAlias = list[str | Weights] | dict[str, list[str | Weights]] + + +logger: logging.Logger = logging.getLogger(__name__) + + +def is_pt2_package(serialized_model: bytes | str) -> bool: + """ + Check if the serialized model is a PT2 Archive package. + """ + try: + with zipfile.ZipFile( + io.BytesIO(serialized_model) + if isinstance(serialized_model, bytes) + else serialized_model + ) as zip_reader: + root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] + archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" + if archive_format_path in zip_reader.namelist(): + return zip_reader.read(archive_format_path) == b"pt2" + except Exception: + logger.info("Model is not a PT2 package") + return False + + +class PT2ArchiveWriter: + """ + Context manager for writing a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) + self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] + # NOTICE: version here is different from the archive_version + # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version + # archive_version is the version of the PT2 archive spec, which write to /archive_version + self.archive_file.set_min_version(6) + + def __enter__(self) -> "PT2ArchiveWriter": + return self + + def __exit__(self, *args: Any) -> None: + if not self.has_record(ARCHIVE_FORMAT_PATH): + self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE) + + if not self.has_record(ARCHIVE_VERSION_PATH): + self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE) + + self.close() + + def has_record(self, name: str) -> bool: + """ + Check if a record exists in the archive. + """ + return name in self.archive_file.get_all_written_records() + + def count_prefix(self, prefix: str) -> int: + """ + Count the number of records that start with a given prefix. + """ + return sum( + 1 + for record in self.archive_file.get_all_written_records() + if record.startswith(prefix) + ) + + def write_bytes(self, name: str, data: bytes) -> None: + """ + Write a bytes object to the archive. + name: The destination file inside the archive. + data: The bytes object to write. + """ + assert isinstance(data, bytes), f"Expected bytes but got {type(data)}" + self.archive_file.write_record(name, data, len(data)) + + def write_string(self, name: str, data: str) -> None: + """ + Write a string object to the archive. + name: The destination file inside the archive. + data: The string object to write. + """ + assert isinstance(data, str), f"Expected string but got {type(data)}" + data_bytes = data.encode() + self.write_bytes(name, data_bytes) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert os.path.isfile(file_path), f"{file_path} is not a valid file path" + + with open(file_path, "rb") as f: + file_bytes = f.read() + self.write_bytes(name, file_bytes) + + def write_folder(self, archive_dir: str, folder_dir: str) -> None: + """ + Copy a folder into the archive. + archive_dir: The destination folder inside the archive. + folder_dir: The source folder on disk. + """ + assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path" + + file_paths = filter( + os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True) + ) + for file_path in file_paths: + # pyrefly: ignore [no-matching-overload] + filename = os.path.relpath(file_path, folder_dir) + archive_path = os.path.join(archive_dir, filename) + # pyrefly: ignore [bad-argument-type] + self.write_file(archive_path, file_path) + + def close(self) -> None: + """ + Close the archive. + """ + self.archive_file.write_end_of_file() + + +class PT2ArchiveReader: + """ + Context manager for reading a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + if isinstance(archive_path_or_buffer, str): + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) + self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) + + def __enter__(self) -> "PT2ArchiveReader": + return self + + def __exit__(self, *args: Any) -> None: + # torch._C.PyTorchFileReader doesn't have a close method + pass + + def read_bytes(self, name: str) -> bytes: + """ + Read a bytes object from the archive. + name: The source file inside the archive. + """ + return self.archive_file.get_record(name) + + def read_string(self, name: str) -> str: + """ + Read a string object from the archive. + name: The source file inside the archive. + """ + data = self.read_bytes(name) + return data.decode() + + def archive_version(self) -> int: + """ + Get the archive version. + """ + try: + archive_version = self.read_string(ARCHIVE_VERSION_PATH) + except Exception: + # if archive_version is not found, it means the archive is older than version 0. + # In this case, we assume the archive is version 0. + archive_version = "0" + + return int(archive_version) + + def get_file_names(self) -> list[str]: + """ + Get the file names in the archive. + """ + return self.archive_file.get_all_records() + + +is_pt2_package.__module__ = "torch.export.pt2_archive" +PT2ArchiveWriter.__module__ = "torch.export.pt2_archive" +PT2ArchiveReader.__module__ = "torch.export.pt2_archive" + + +def _package_aoti_files( + archive_writer: PT2ArchiveWriter, + aoti_files: AOTI_FILES | None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if aoti_files is None: + return + + if isinstance(aoti_files, list): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + + all_weights: dict[str, Weights] = {} # model_name -> weight + weights_configs: dict[ + str, dict[str, Any] + ] = {} # model_name -> (weight_name -> (filename, shape, stride, offset)) + + for model_name, files in aoti_files.items(): + num_so_files = 0 + weights_configs[model_name] = {} + + for file in files: + if file == "": + continue + + if isinstance(file, Weights): + all_weights[model_name] = file + continue + + if file.endswith(".so"): + num_so_files += 1 + if num_so_files > 1: + raise RuntimeError( + f"Multiple .so files found in {files}. " + "You might need to clear your cache " + "directory before calling aoti_compile again." + ) + + filename = os.path.basename(file) + if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + new_filepath = os.path.join(CONSTANTS_DIR, filename) + else: + new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) + logger.debug( + "Saving AOTI generated file %s to archive in %s", file, new_filepath + ) + archive_writer.write_file( + str(new_filepath), + file, + ) + + if len(all_weights) > 0: + # Dedup weights + grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights) + for idx, group in enumerate(grouped_tensors): + filename = f"{WEIGHT_FILENAME_PREFIX}{idx}" + model_name, weight_name = get_complete(group, all_weights) + complete_tensor, _ = all_weights[model_name].get_weight(weight_name) + buffer = io.BytesIO() + torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes( + os.path.join(WEIGHTS_DIR, filename), buffer.getvalue() + ) + for model_name, weight_name in group: + _, w_property = all_weights[model_name].get_weight(weight_name) + weights_configs[model_name][weight_name] = ( + filename, + w_property.shape, + w_property.stride, + w_property.offset, + ) + + for model_name, weights_config in weights_configs.items(): + archive_writer.write_string( + os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"), + json.dumps(weights_config), + ) + logger.debug("packaging weights_config for model %s", model_name) + logger.debug(weights_config) + + +def _is_fake_tensor(t: torch.Tensor) -> bool: + return isinstance(t, FakeTensor) + + +def _is_tensor_subclass(t: torch.Tensor) -> bool: + return isinstance(t, torch.Tensor) and type(t.data) is not torch.Tensor + + +def _get_raw_tensor_bytes(value: torch.Tensor) -> bytes: + """ + Get the raw bytes of a tensor. This is used to save the tensor in pt2 archive. + """ + # NOTE: don't chain .cpu() with .data_ptr(). If an HtoD copy needs to be + # performed, the CPU copy needs to be kept alive when its underlying + # memory is accessed. + import ctypes + + if _is_fake_tensor(value): + value_bytes = b"" + elif value.data_ptr(): + cpu_tensor = value.cpu() + value_untyped_storage = cpu_tensor.untyped_storage() + # we store the raw bytes the untyped storage. Tensor metadata is stored separately + value_bytes = bytes( + ctypes.cast( + value_untyped_storage.data_ptr(), + ctypes.POINTER(ctypes.c_ubyte * value_untyped_storage.size()), + ).contents + ) + else: + # for empty tensor + value_bytes = b"" + return value_bytes + + +def _should_use_pickle(t: torch.Tensor) -> bool: + return _is_tensor_subclass(t) and not _is_fake_tensor(t) + + +def _save_pickled_tensors( + pickled_items: list[tuple[str, torch.Tensor]], + archive_writer: PT2ArchiveWriter, + config: dict[str, schema.PayloadMeta], + directory: str, + filename_prefix: str, + idx: int, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> int: + """Save pickled tensors and update config. Returns updated index.""" + for item_fqn, tensor in pickled_items: + path_name = f"{filename_prefix}{idx}" + archive_path = os.path.join(directory, path_name) + buffer = io.BytesIO() + torch.save(tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes(archive_path, buffer.getvalue()) + + config[item_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=isinstance(tensor, torch.nn.Parameter), + use_pickle=True, + tensor_meta=serialize_tensor_meta(tensor), + ) + idx += 1 + return idx + + +def _save_raw_tensors( + raw_items: dict[str, tuple[torch.Tensor, TensorProperties]], + model_name: str, + archive_writer: PT2ArchiveWriter, + config: dict[str, schema.PayloadMeta], + directory: str, + filename_prefix: str, + idx: int, +) -> int: + """Save deduplicated raw tensor bytes and update config. Returns updated index.""" + if not raw_items: + return idx + + weights_dict = {model_name: Weights(raw_items)} + storage_groups = group_weights(weights_dict) + + for group in storage_groups: + # Find the complete tensor that covers all others in this storage group + model_name, complete_item_name = get_complete(group, weights_dict) + complete_tensor, _ = weights_dict[model_name].get_weight(complete_item_name) + + path_name = f"{filename_prefix}{idx}" + archive_path = os.path.join(directory, path_name) + tensor_bytes = _get_raw_tensor_bytes(complete_tensor) + archive_writer.write_bytes(archive_path, tensor_bytes) + idx += 1 + + for _, item_fqn in group: + tensor, _ = weights_dict[model_name].get_weight(item_fqn) + config[item_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=isinstance(tensor, torch.nn.Parameter), + use_pickle=False, + tensor_meta=serialize_tensor_meta(tensor), + ) + + return idx + + +def _package_state_dict( + model_name: str, + exported_program: ExportedProgram, + archive_writer: PT2ArchiveWriter, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> schema.PayloadConfig: + weights_config: dict[str, schema.PayloadMeta] = {} + + pickled_weights: list[tuple[str, torch.Tensor]] = [] + raw_weights: dict[str, tuple[torch.Tensor, TensorProperties]] = {} + + # Categorize weights + for weight_fqn, weight_tensor in exported_program.state_dict.items(): + assert isinstance(weight_tensor, torch.Tensor), ( + "only torch.Tensor is allowed in state_dict" + ) + if _should_use_pickle(weight_tensor): + pickled_weights.append((weight_fqn, weight_tensor)) + else: + raw_weights[weight_fqn] = (weight_tensor, TensorProperties(weight_tensor)) + + idx = archive_writer.count_prefix(os.path.join(WEIGHTS_DIR, WEIGHT_FILENAME_PREFIX)) + + # Save weights in pickle format + idx = _save_pickled_tensors( + pickled_weights, + archive_writer, + weights_config, + WEIGHTS_DIR, + WEIGHT_FILENAME_PREFIX, + idx, + pickle_protocol, + ) + + # Save weights in raw bytes format + _save_raw_tensors( + raw_weights, + model_name, + archive_writer, + weights_config, + WEIGHTS_DIR, + WEIGHT_FILENAME_PREFIX, + idx, + ) + + return schema.PayloadConfig(config=weights_config) + + +def _package_constants( + model_name: str, + exported_program: ExportedProgram, + archive_writer: PT2ArchiveWriter, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> schema.PayloadConfig: + constants_config: dict[str, schema.PayloadMeta] = {} + + pickled_constants: list[tuple[str, torch.Tensor]] = [] + raw_constants: dict[str, tuple[torch.Tensor, TensorProperties]] = {} + custom_objects: list[tuple[str, torch._C.ScriptObject]] = [] + + # Categorize constants + for constant_fqn, constant in exported_program.constants.items(): + if isinstance(constant, torch.Tensor): + if _should_use_pickle(constant): + pickled_constants.append((constant_fqn, constant)) + else: + raw_constants[constant_fqn] = (constant, TensorProperties(constant)) + + elif isinstance(constant, torch._C.ScriptObject): + custom_objects.append((constant_fqn, constant)) + + else: + raise RuntimeError(f"Unsupported constant type: {type(constant)}") + + tensor_idx = archive_writer.count_prefix( + os.path.join(CONSTANTS_DIR, TENSOR_CONSTANT_FILENAME_PREFIX) + ) + custom_obj_idx = archive_writer.count_prefix( + os.path.join(CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX) + ) + + # Save constants in pickle format + tensor_idx = _save_pickled_tensors( + pickled_constants, + archive_writer, + constants_config, + CONSTANTS_DIR, + TENSOR_CONSTANT_FILENAME_PREFIX, + tensor_idx, + pickle_protocol, + ) + + # Save constants in raw bytes format + _save_raw_tensors( + raw_constants, + model_name, + archive_writer, + constants_config, + CONSTANTS_DIR, + TENSOR_CONSTANT_FILENAME_PREFIX, + tensor_idx, + ) + + # Handle custom objects + for constant_fqn, constant in custom_objects: + path_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}" + archive_path = os.path.join(CONSTANTS_DIR, path_name) + custom_obj_bytes = torch._C._pickle_save(constant) + archive_writer.write_bytes(archive_path, custom_obj_bytes) + + constants_config[constant_fqn] = schema.PayloadMeta( + path_name=path_name, + is_param=False, + use_pickle=True, + tensor_meta=None, + ) + custom_obj_idx += 1 + + return schema.PayloadConfig(config=constants_config) + + +def _package_payload_config( + archive_writer: PT2ArchiveWriter, + payload_config: schema.PayloadConfig, + config_file: str, +) -> None: + """ + Save the payload config as json file in the archive. + """ + archive_writer.write_string( + config_file, json.dumps(_dataclass_to_dict(payload_config)) + ) + + +def _package_exported_programs( + archive_writer: PT2ArchiveWriter, + exported_programs: ExportedProgram | dict[str, ExportedProgram] | None, + opset_version: dict[str, int] | None = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if exported_programs is None: + return + + if isinstance(exported_programs, ExportedProgram): + exported_programs = {"model": exported_programs} + + assert isinstance(exported_programs, dict) + + for model_name, ep in exported_programs.items(): + weights_config = _package_state_dict( + model_name, ep, archive_writer, pickle_protocol + ) + weights_config_file = WEIGHTS_CONFIG_FILENAME_FORMAT.format(model_name) + _package_payload_config(archive_writer, weights_config, weights_config_file) + + constants_config = _package_constants( + model_name, ep, archive_writer, pickle_protocol + ) + constants_config_file = CONSTANTS_CONFIG_FILENAME_FORMAT.format(model_name) + _package_payload_config(archive_writer, constants_config, constants_config_file) + + artifact: SerializedArtifact = serialize( + ep, + opset_version, + pickle_protocol, + ) + + archive_writer.write_bytes( + MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program + ) + archive_writer.write_bytes( + SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), + artifact.example_inputs, + ) + + +def _package_extra_files( + archive_writer: PT2ArchiveWriter, extra_files: dict[str, Any] | None +) -> None: + if extra_files is None: + return + + for extra_file_name, content in extra_files.items(): + archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) + + +def _package_executorch_files( + archive_writer: PT2ArchiveWriter, executorch_files: dict[str, bytes] | None +) -> None: + if executorch_files is None: + return + + for file_name, content in executorch_files.items(): + archive_writer.write_bytes(f"{EXECUTORCH_DIR}{file_name}", content) + + +def package_pt2( + f: FileLike, + *, + exported_programs: ExportedProgram | dict[str, ExportedProgram] | None = None, + aoti_files: AOTI_FILES | None = None, + extra_files: dict[str, Any] | None = None, + opset_version: dict[str, int] | None = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, + executorch_files: dict[str, bytes] | None = None, +) -> FileLike: + r""" + Saves the artifacts to a PT2Archive format. The artifact can then be loaded + using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): + The exported program to save, or a dictionary mapping model name to an + exported program to save. The exported program will be saved under + models/\*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]]): A list of files + generated by AOTInductor via + ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, + or a dictionary mapping model name to its AOTInductor generated files. + If only one set of files is specified, this will automatically be named + "model". + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of the pt2. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + executorch_files (Optional[dict[str, bytes]]): Optional executorch + artifacts to save. + + """ + assert not ( + exported_programs is None and aoti_files is None and extra_files is None + ), ( + "No value passed in for `exported_programs`, `aoti_files`, and " + "`extra_files`, implying that you do not plan on saving anything." + ) + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2")) + ): + # TODO: turn this into an error + logger.warning( + "Expect archive file to be a file ending in .pt2, or is a buffer. " + "Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + # pyrefly: ignore [bad-argument-type] + with PT2ArchiveWriter(f) as archive_writer: + _package_exported_programs( + archive_writer, exported_programs, pickle_protocol=pickle_protocol + ) + _package_aoti_files( + archive_writer, + aoti_files, + pickle_protocol=pickle_protocol, + ) + _package_extra_files(archive_writer, extra_files) + _package_executorch_files(archive_writer, executorch_files) + + if isinstance(f, (io.IOBase, IO)): + f.seek(0) + # pyrefly: ignore [bad-return] + return f + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = self.loader.boxed_run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + def get_metadata(self) -> dict[str, str]: + return self.loader.get_metadata() + + def load_constants( + self, + constants_map: dict[str, torch.Tensor], + *, + check_full_update: bool, + user_managed: bool = False, + ) -> None: + """ + Given a mapping of constant fqns to tensors, load the constants into the model. + You can use ``get_constant_fqns`` to get the list of constant fqns that + are needed in the compiled model. + + Args: + constants_map: A mapping of constant fqns to tensors. + check_full_update: Whether to add check to see if all the constants + are updated and have values. + """ + self.loader.load_constants( + constants_map, False, check_full_update, user_managed + ) + + def get_constant_fqns(self) -> list[str]: + return self.loader.get_constant_fqns() + + def __deepcopy__(self, memo: dict[Any, Any] | None) -> "AOTICompiledModel": + logger.warning( + "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." + ) + return AOTICompiledModel(self.loader) + + +@dataclass +class PT2ArchiveContents: + exported_programs: dict[str, ExportedProgram] + aoti_runners: dict[str, AOTICompiledModel] + extra_files: dict[str, Any] + + +def _create_flat_tensor_from_bytes( + tensor_bytes: bytes, + tensor_meta: schema.TensorMeta, +) -> torch.Tensor: + """ + Create a flat tensor from raw bytes with dtype, device and requires_grad. + It will be re-strided based on size, stride, and storage_offset later. + """ + dtype = deserialize_scalar_type(tensor_meta.dtype) + size = deserialize_size(tensor_meta.sizes) + device = deserialize_device(tensor_meta.device) + + if len(tensor_bytes) != 0: + tensor = torch.frombuffer( + tensor_bytes, dtype=dtype, requires_grad=tensor_meta.requires_grad + ).to(device) + else: + # cannot call torch.frombuffer() on empty bytes + logger.warning( + "Cannot call torch.frombuffer() on empty bytes. " + "Creating a tensor with zeros as workaround." + ) + tensor = torch.zeros(size, dtype=dtype, device=device) + + return tensor + + +def _build_file_map( + archive_reader: PT2ArchiveReader, + config: schema.PayloadConfig, + base_dir: str, +) -> dict[str, torch.Tensor]: + """ + Build a map from file path to the payload in flat tensor format. + """ + file_map: dict[str, torch.Tensor] = {} + for payload_meta in config.config.values(): + # skip pickled objects + if payload_meta.use_pickle: + continue + # skip files that already exist in the map + if payload_meta.path_name in file_map: + continue + + tensor_bytes = archive_reader.read_bytes( + os.path.join(base_dir, payload_meta.path_name) + ) + assert payload_meta.tensor_meta is not None + tensor = _create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta) + file_map[payload_meta.path_name] = tensor + + return file_map + + +def _load_payload_config( + archive_reader: PT2ArchiveReader, + config_file: str, +) -> schema.PayloadConfig: + """ + Load and parse a payload config from the archive. + """ + return _dict_to_dataclass( + schema.PayloadConfig, + json.loads(archive_reader.read_string(config_file)), + ) + + +def _load_state_dict( + archive_reader: PT2ArchiveReader, + model_name: str, +) -> dict[str, torch.Tensor] | bytes: + # Make it BC compatible with legacy weight files + legacy_weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + if legacy_weights_file in archive_reader.get_file_names(): + logger.warning( + "You are loading weight from the legacy format. " + "Please generate a new pt2 file using torch.export.save()." + ) + return archive_reader.read_bytes(legacy_weights_file) + else: + weights_config_file = WEIGHTS_CONFIG_FILENAME_FORMAT.format(model_name) + assert weights_config_file in archive_reader.get_file_names(), ( + f"{weights_config_file} not found in PT2 archive" + ) + weights_config = _load_payload_config(archive_reader, weights_config_file) + # construct the mapping from file name (e.g. weight_0) to flat weight payload + state_dict_file_map = _build_file_map( + archive_reader, weights_config, WEIGHTS_DIR + ) + # chain the mapping weight FQN -> weight file name -> strided weight payload + # so that the aliasing of weights is preserved + state_dict: dict[str, torch.Tensor] = {} + for weight_fqn, payload_meta in weights_config.config.items(): + if payload_meta.use_pickle: + weight_bytes = archive_reader.read_bytes( + os.path.join(WEIGHTS_DIR, payload_meta.path_name) + ) + state_dict[weight_fqn] = torch.load( + io.BytesIO(weight_bytes), weights_only=False + ) + else: + tensor_meta = payload_meta.tensor_meta + assert tensor_meta is not None + weight_tensor = torch.as_strided( + input=state_dict_file_map[payload_meta.path_name], + size=deserialize_size(tensor_meta.sizes), + stride=deserialize_stride(tensor_meta.strides), + storage_offset=deserialize_storage_offset( + tensor_meta.storage_offset + ), + ) + if payload_meta.is_param: + state_dict[weight_fqn] = torch.nn.Parameter( + weight_tensor, requires_grad=tensor_meta.requires_grad + ) + else: + state_dict[weight_fqn] = weight_tensor + + return state_dict + + +def _load_constants( + archive_reader: PT2ArchiveReader, + model_name: str, +) -> dict[str, torch.Tensor] | bytes: + # Make it BC compatible with legacy constant files + legacy_constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + if legacy_constants_file in archive_reader.get_file_names(): + logger.warning( + "You are loading constant from the legacy format. " + "Please generate a new pt2 file using torch.export.save()." + ) + return archive_reader.read_bytes(legacy_constants_file) + else: + constants_config_file = CONSTANTS_CONFIG_FILENAME_FORMAT.format(model_name) + assert constants_config_file in archive_reader.get_file_names(), ( + f"{constants_config_file} not found in PT2 archive" + ) + constants_config = _load_payload_config(archive_reader, constants_config_file) + # construct the mapping from file name (e.g. constant_0) to constant payload + constant_file_map = _build_file_map( + archive_reader, constants_config, CONSTANTS_DIR + ) + # chain the mapping constant FQN -> constant file name -> strided constant payload + # so that the aliasing of constants is preserved + constants: dict[str, torch.Tensor] = {} + for constant_fqn, payload_meta in constants_config.config.items(): + path_name = payload_meta.path_name + if path_name.startswith(TENSOR_CONSTANT_FILENAME_PREFIX): + if payload_meta.use_pickle: + constant_bytes = archive_reader.read_bytes( + os.path.join(CONSTANTS_DIR, path_name) + ) + constants[constant_fqn] = torch.load( + io.BytesIO(constant_bytes), weights_only=False + ) + else: + tensor_meta = payload_meta.tensor_meta + assert tensor_meta is not None + constant_tensor = torch.as_strided( + input=constant_file_map[path_name], + size=deserialize_size(tensor_meta.sizes), + stride=deserialize_stride(tensor_meta.strides), + storage_offset=deserialize_storage_offset( + tensor_meta.storage_offset + ), + ) + constants[constant_fqn] = constant_tensor + + elif path_name.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + constant_bytes = archive_reader.read_bytes( + os.path.join(CONSTANTS_DIR, path_name) + ) + constants[constant_fqn] = torch._C._pickle_load_obj(constant_bytes) + + else: + raise RuntimeError(f"Unsupported constant type: {path_name}") + + return constants + + +def _load_exported_programs( + archive_reader: PT2ArchiveReader, + file_names: list[str], + expected_opset_version: dict[str, int] | None, +) -> dict[str, ExportedProgram]: + exported_program_files = [ + file for file in file_names if file.startswith(MODELS_DIR) + ] + exported_programs = {} + for file in exported_program_files: + prefix, suffix = MODELS_FILENAME_FORMAT.split( + "{}" + ) # split "models/{}.json" into "models/" and "json" + model_name = file[ + len(prefix) : -len(suffix) + ] # given "models/foo.json" we can now get "foo" + + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + from torch._export.serde.serialize import _bytes_to_dataclass + + exported_program_bytes = archive_reader.read_bytes(file) + serialized_exported_program = _bytes_to_dataclass( + schema.ExportedProgram, exported_program_bytes + ) + state_dict = _load_state_dict(archive_reader, model_name) + constants = _load_constants(archive_reader, model_name) + + ep = ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, + state_dict, + constants, + serialized_sample_inputs, + ) + + exported_programs[model_name] = ep + + return exported_programs + + +def _load_extra_files( + archive_reader: PT2ArchiveReader, file_names: list[str] +) -> dict[str, Any]: + extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] + + extra_file_contents: dict[str, Any] = {} + for file in extra_files: + contents = archive_reader.read_string(file) + extra_file_contents[file[len(EXTRA_DIR) :]] = contents + + return extra_file_contents + + +def _load_aoti( + file: str, + model_name: str, + run_single_threaded: bool, + num_runners: int, + device_idx: int, +) -> AOTICompiledModel: + loaded_metadata = torch._C._aoti.AOTIModelPackageLoader.load_metadata_from_package( # type: ignore[attr-defined] + file, model_name + ) + + device = loaded_metadata["AOTI_DEVICE_KEY"] + current_device_info = torch._inductor.codecache.get_device_information(device) + + for k, v in current_device_info.items(): + if k in loaded_metadata: + if v != loaded_metadata[k]: + logger.warning( + "Device information mismatch for %s: %s vs %s. " + "This could cause some issues when loading the AOTInductor compiled artifacts.", + k, + v, + loaded_metadata[k], + ) + + aoti_compiled_model = AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + file, + model_name, + run_single_threaded, + num_runners, + device_idx, + ) + ) + + return aoti_compiled_model + + +def load_pt2( + f: FileLike, + *, + expected_opset_version: dict[str, int] | None = None, + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, + load_weights_from_disk: bool = False, +) -> PT2ArchiveContents: # type: ignore[type-arg] + """ + Loads all the artifacts previously saved with ``package_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + num_runners (int): Number of runners to load AOTInductor artifacts + + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + + Returns: + A ``PT2ArchiveContents`` object which contains all the objects in the PT2. + """ + + from torch._inductor.cpp_builder import normalize_path_separator + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error in 2.9 + logger.warning( + "Unable to load package. f must be a buffer or a file ending in " + ".pt2. Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + weights = {} + weight_maps = {} + # pyrefly: ignore [bad-argument-type] + with PT2ArchiveReader(f) as archive_reader: + version = archive_reader.read_string(ARCHIVE_VERSION_PATH) + if version != ARCHIVE_VERSION_VALUE: + raise ValueError( + f"Saved archive version {version} does not match our current " + f"archive version {ARCHIVE_VERSION_VALUE}." + ) + + file_names = archive_reader.get_file_names() + + exported_programs = _load_exported_programs( + archive_reader, file_names, expected_opset_version + ) + extra_files = _load_extra_files(archive_reader, file_names) + + # Get a list of AOTI model names + aoti_model_names: set[str] = set() + for file in file_names: + if file.startswith(AOTINDUCTOR_DIR): + file_end = file[ + len(AOTINDUCTOR_DIR) : + ] # remove data/aotinductor/ prefix + file_end = normalize_path_separator( + file_end + ) # Win32 need normalize path before split. + model_name = file_end.split("/")[ + 0 + ] # split "model_name/...cpp" into "model_name" + aoti_model_names.add(model_name) + if load_weights_from_disk and file.endswith("weights_config.json"): + weight_map = json.loads(archive_reader.read_string(file)) + weight_maps[model_name] = weight_map + elif load_weights_from_disk and file.startswith(WEIGHTS_DIR): + weight_file_name = file[ + len(WEIGHTS_DIR) : + ] # remove data/weights/ prefix + weight_bytes = archive_reader.read_bytes(file) + loaded_weight = torch.load(io.BytesIO(weight_bytes)) + weights[weight_file_name] = loaded_weight + + if isinstance(f, (io.IOBase, IO)): + if len(aoti_model_names) > 0: + # Workaround for AOTIModelPackageLoader not reading buffers + with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: + f.seek(0) + tf.write(f.read()) + f.seek(0) + logger.debug("Writing buffer to tmp file located at %s.", tf.name) + + aoti_runners = { + model_name: _load_aoti( + tf.name, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + for model_name in aoti_model_names + } + else: + aoti_runners = {} + else: + aoti_runners = { + model_name: _load_aoti( + f, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + for model_name in aoti_model_names + } + + if weight_maps: + for model_name in aoti_model_names: + model_weights = {} + for weight_name, (file, shape, stride, storage_offset) in weight_maps[ + model_name + ].items(): + weight = weights[file] + model_weights[weight_name] = weight.as_strided( + shape, stride, storage_offset + ) + + # user_managed=True ensures the weights updates are shared by all runners. + aoti_runners[model_name].load_constants( + model_weights, check_full_update=True, user_managed=True + ) + + return PT2ArchiveContents(exported_programs, aoti_runners, extra_files) + + +def load_weights_to_pt2_contents( + pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any] +) -> None: + """ + Load weights into the models in PT2 archive contents + + Args: + pt2_contents (PT2ArchiveContents): The contents of the PT2 archive. + """ + for model_name, weights in weights_map.items(): + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.") + pt2_contents.aoti_runners[model_name].load_constants( + weights, check_full_update=True, user_managed=True + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd86feebf0a691d7e527e4ea382e7b4aaabf9c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py @@ -0,0 +1,135 @@ +import collections +import warnings + +import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._ordered_set import OrderedSet + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +class TensorProperties: + def __init__(self, tensor: torch.Tensor): + self.is_fake = isinstance(tensor, FakeTensor) + self.is_contiguous = tensor.is_contiguous() + self.storage_ptr = None + self.storage_size = None + self.start = None + self.end = None + + if not self.is_fake: + # only get the storage pointer for real tensors + # pyrefly: ignore [bad-assignment] + self.storage_ptr = tensor.untyped_storage().data_ptr() + if self.is_contiguous: + # only get storage size and start/end pointers for contiguous tensors + # pyrefly: ignore [bad-assignment] + self.storage_size = tensor.untyped_storage().nbytes() + # pyrefly: ignore [bad-assignment] + self.start = tensor.data_ptr() + # pyrefly: ignore [bad-assignment] + self.end = _end_ptr(tensor) + + # info to recover tensor + self.shape = tensor.shape + self.stride = tensor.stride() + self.offset = tensor.storage_offset() + + def is_complete(self) -> bool: + """ + Whether the tensor completely overlaps with its underlying storage + """ + if self.is_fake: + # Theoretically, fake tensors should not appear in weights + # But we handle this corner case to make it always complete + return True + if not self.is_contiguous: + return False + + assert self.storage_ptr is not None + assert self.storage_size is not None + assert self.start is not None + assert self.end is not None + return ( + self.start == self.storage_ptr + and self.end == self.storage_ptr + self.storage_size + ) + + +class Weights(dict): + """ + A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). + tensor represents the actual initial value of the weight. + TensorProperties represents the properties of the weight that are needed to recover the weight. + + We use two separate entries because `tensor` could be a clone of the original weight tensor, + so it doesn't have the same property as the original weight (such as underlying storage pointer). + """ + + def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]): + super().__init__(weight_dict) + + def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]: + return self[name] + + def get_weight_properties(self, name: str) -> TensorProperties: + return self[name][1] + + +def get_complete( + group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights] +) -> tuple[str, str]: + """ + `group` is a (model_name, weight_name) tuple. + `model_weights` is a dictionary mapping from model name to its Weights. + + One of the tensor in `group` must be complete and they must share the + same underlying storage. + + Returns the name of the complete tensor in the `group`. If multiple + tensors are complete, returns an arbitrary one. + """ + + def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties: + # returns the tensor properties + (model_name, weight_name) = name_tuple + return models_weights[model_name].get_weight_properties(weight_name) + + for name_tuple in group: + tensor_property = get_tensor_properties(name_tuple) + if tensor_property.is_complete(): + return name_tuple + + warnings.warn( + "No complete tensor found in the group! Returning the first one. " + "This may cause issues when your weights are not on CPU.", + stacklevel=2, + ) + assert len(group) > 0 + return next(iter(group)) + + +def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]: + """ + Group weights that share the same underlying storage. + + Returns a list of sets, each set contains a tuple of (model_name, weight_name). + """ + + weights_dict: dict[tuple[int, torch.dtype], OrderedSet[tuple[str, str]]] = ( + collections.defaultdict(OrderedSet) + ) # (storage_key, dtype) -> set(weight) + + for model_name, weights in all_weights.items(): + for weight_name, (tensor, properties) in weights.items(): + weights_dict[(properties.storage_ptr, tensor.dtype)].add( + (model_name, weight_name) + ) + + return list(weights_dict.values()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..4b05e257b8f3dfc387b553f0aeecc7a0e1653528 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py @@ -0,0 +1,35 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h +from torch._C._export import pt2_archive_constants + + +AOTINDUCTOR_DIR: str = pt2_archive_constants.AOTINDUCTOR_DIR +ARCHIVE_FORMAT_PATH: str = pt2_archive_constants.ARCHIVE_FORMAT_PATH +ARCHIVE_FORMAT_VALUE: str = pt2_archive_constants.ARCHIVE_FORMAT_VALUE +ARCHIVE_ROOT_NAME: str = pt2_archive_constants.ARCHIVE_ROOT_NAME +ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH +ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE +CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR +CONSTANTS_CONFIG_FILENAME_FORMAT: str = ( + pt2_archive_constants.CONSTANTS_CONFIG_FILENAME_FORMAT +) +CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX +EXECUTORCH_DIR: str = pt2_archive_constants.EXECUTORCH_DIR +EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR +MODELS_DIR: str = pt2_archive_constants.MODELS_DIR +MODELS_FILENAME_FORMAT: str = pt2_archive_constants.MODELS_FILENAME_FORMAT +MODULE_INFO_PATH: str = pt2_archive_constants.MODULE_INFO_PATH +MTIA_DIR: str = pt2_archive_constants.MTIA_DIR +SAMPLE_INPUTS_DIR: str = pt2_archive_constants.SAMPLE_INPUTS_DIR +SAMPLE_INPUTS_FILENAME_FORMAT: str = pt2_archive_constants.SAMPLE_INPUTS_FILENAME_FORMAT +TENSOR_CONSTANT_FILENAME_PREFIX: str = ( + pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX +) +WEIGHTS_CONFIG_FILENAME_FORMAT: str = ( + pt2_archive_constants.WEIGHTS_CONFIG_FILENAME_FORMAT +) +WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX +WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR +XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ( + pt2_archive_constants.XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53be63bdc9812d1b7439a3a82772bbb2b6317a74 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39bb044ce4fc770ca0f8c5d15adaae84b66b56e6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b57151095b226b50136c75f86d7ee42125e8d36d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d0358629b2f72a11dd467a5de3f9698f1689b1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a340b257d11c56cc95d4a4f8495391b7a6233020 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95afd58cc64125a569dc05a50fffc72ef9a0be25 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a289627379fec4762f4c5604c4707d4c8b3dada Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8896efe77804e7de694abb4a7863d99d29e14fa8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/annotate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc9f7297b79cec2f3a9fbeba6ef61e348f3b4a74 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c444883822bf7665b9fec3cbffd87a60c51add7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..612aa820e7ed42528f6b94f86308c6eed8da7c26 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a28255485b91305d0d12c4b268c68532977bb84 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de44bcf542765e74fe3e368de4aa69ace9b76010 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6caffd7d5f3e630153ac8940af1ef86bbc19269 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0141df11a21d179c0b4d928e1436ea08e3663f81 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4cd2f9a6f7c8d69067e836072f0285a0135f8c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bccc6c57d71f060364247ecf78a218a089a1ca85 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/tensor_type.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c995971ec6a0dfa92c1431a26eccd75fcc8bc39b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b856ecb44b5648c4b7d33c82c9427796051db125 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37b5f0dbfff3e58af36233033de9e8fb3e4152d7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fdb1c07c790376d04161248d3459498c4d53741 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5c8d7ffc76e17d5c0354ffb09cde04323d8f655 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8919dc8c4d711b05793cc93b1e96407ef6514f2c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..341a1945606365a39b47d313d7ae4fa03dfe5f18 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26e04856b554188c14bb96223e31ff43bc619a14 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae633dd275f153d42c54842bb63b32d1e660818f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50f0e1e7498a8d5bba0eb5d3f04a6944ee3538ec Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96bc261fa67fee44aa379d38f0d47ba875465a5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e4fab7ce2e5faf69ce2aaee46182584df4162ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60b73592a19c12e50352bbdf728d4d1a5f57ce02 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8a03196bc06dd4dcde2b32cb69adf0d357aa559 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fcb2a78db26fff9a2f23f278350f8bd4cb4d364 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7685556ed192544e89633e31e2bc6a3aa6e70df6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2ae9930d7910791f00269679fbc11e39909754a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa8734050f7aee838623f46396f6cbf04162b57 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/validator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9c742431857c33af22dbc1ad73b5bdfcf6124b9c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py @@ -0,0 +1,27 @@ +import torch.fx + + +class BackwardState: + """ + BackwardState is used to pass Python hooks from the forwards pass + into the backwards pass in Dynamo+Compiled Autograd. + + It is created by TorchDynamo and has special handling there. + Dynamo will pass an empty BackwardState to the forwards, then populate + members on it (via setattr) only after the forwards graph is finished. + Later on, in CompileAutograd we will inline and add the needed guards + on the BackwardState. + + BackwardState is identified and has special handling in AOTAutograd. + During AOTAutograd: + 1) BackwardState is an input to the forwards graph + 2) It must only be used in the backwards + 3) It will be empty in the forwards + 4) In the forwards we add a wrapper to save it + 5) In the backwards it becomes an input + 6) There can only be one per graph + + BackwardState requires CompiledAutograd. + """ + + proxy: torch.fx.Proxy diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a537978db3834d0bbb425bbd4214a8b17163db18 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py @@ -0,0 +1,112 @@ +import os +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. +no_data_dependent_graph_break = ( + os.environ.get("TORCHDYNAMO_NO_DATA_DEPENDENT_GRAPH_BREAK", "0") == "1" +) +# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations. +translation_validation = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1" +) +# Timeout (in milliseconds) for z3 finding a solution. +# [@compile_ignored: debug] +translation_validation_timeout = int( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000") +) +# Disables bisection for translation validation. +# +# Translation validation bisection is enabled by default, if translation validation +# is also enabled. This should help finding guard simplification issues. However, +# since validation uses Z3 for bisecting, it might take a lot of time. +# +# Set this configuration option so as to avoid bisecting. +# [@compile_ignored: debug] +translation_validation_no_bisect = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1" +) +# Checks whether replaying ShapeEnv events on a freshly constructed one yields +# the a ShapeEnv with the same state. This should be used only in testing. +check_shape_env_recorded_events = False + +# TODO: Perhaps consider allowing unions for the configs below (so you can hit +# multiple reps at the same time) + +# Give extended debug information if the string representation of a guard +# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue +# this guard, we will generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_guard_added = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None +) + +# Give extended debug information when a particular symbol is allocated. For +# example, set this to "u2" and whenever we create this symbol, we will +# generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_create_symbol = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None +) + +# Give extended debug information (C++ backtrace) for all extended debug +# settings as well as errors. The C++ backtrace is slow and very spammy so we +# don't include it by default even when you're requesting extended debug. +# [@compile_ignored: debug] +extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != "" + +# Give extended debug information (line of code) when a torch function +# is called during export. This is useful for showing progress and detecting +# where export might be stuck. Currently only works for strict=False. +# [@compile_ignored: debug] +extended_debug_current_loc = ( + os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1" +) + +# [@compile_ignored: debug] Show a warning for every specialization +print_specializations = False + +# wraps (un)equalities with 'Not' class after recording the correct expression +# in the FX graph. This should incorrectly construct the divisible and replacement +# lists, and incorrectly issue guards. +inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False + +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_version_key = False + +# If we produce more than this many guards on a symbol, force the symbol to +# get specialized and bail out if this many guards mention this particular +# symbol. This may be slightly more aggressive than the true number of guards +# issued (as we test if we've hit the limit on-the-fly, whereas we may +# do further simplifications at final guard issuance time that make guards +# irrelevant.) +symbol_guard_limit_before_specialize: Optional[int] = None + +# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same. +use_duck_shape = True + +# Controls the registration of torch.nonzero() on the meta device. +# When True, nonzero returns a tensor with shape (self.numel(), self.dim()) +# assuming all elements are none-zero. +# Default is False to prevent unintended registration. Set to True to enable. +meta_nonzero_assume_all_nonzero = False + +# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols, +# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like. +# Currently an experimental option for export. +backed_size_oblivious = False + +# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. +skip_dtype_check_in_meta_registrations = False + +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) + + +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b40bda324c8fd6ad171d14ddb17f52508cb23a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py @@ -0,0 +1,78 @@ +from typing import * # noqa: F403 + + +# Python version of c10/core/ConstantSymNodeImpl.cpp +# This needs to exist because the Python version of nested int is not compatible +# with the C++ version of constant symnode. +class ConstantIntNode: + def __init__(self, val: int): + self.val = val + + def is_constant(self) -> bool: + return True + + def maybe_as_int(self) -> int: + return self.val + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return False + + def clone(self) -> "ConstantIntNode": + return self + + def _str(self) -> str: + return str(self.val) + + def __str__(self) -> str: + return self._str() + + def __repr__(self) -> str: + return self._str() + + def _graph_repr(self) -> str: + return self._str() + + def add(self, other: Any) -> Any: + return other.add(self) + + def sub(self, other: Any) -> Any: + return other.neg().add(self.val) + + def mul(self, other: Any) -> Any: + return other.mul(self) + + def eq(self, other: Any) -> Any: + return other.eq(self) + + def ne(self, other: Any) -> Any: + return other.ne(self) + + def gt(self, other: Any) -> Any: + return other.lt(self) + + def lt(self, other: Any) -> Any: + return other.gt(self) + + def le(self, other: Any) -> Any: + return other.ge(self) + + def ge(self, other: Any) -> Any: + return other.le(self) + + def is_symbolic(self) -> bool: + return False + + def constant_int(self) -> int: + return self.val + + def guard_int(self, file: str, line: int) -> int: + return self.val diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfd41b039e9ec6f9c456fd0240b18902756dc55 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,1085 @@ +# mypy: allow-untyped-defs +import operator +from collections import deque +from typing import NamedTuple + +import torch +from torch.fx.experimental.partitioner_utils import ( + Device, + get_extra_size_of, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, + Partition, + PartitionerConfig, + PartitionMode, +) +from torch.fx.graph_module import GraphModule +from torch.fx.node import map_arg, Node +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes +from torch.fx.passes.split_module import split_module + + +class DAGNode: + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. + """ + + def __init__( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_device_ids: list[int], + size_bytes: int, + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: list[Node] = input_nodes + self.output_nodes: list[Node] = output_nodes + self.logical_device_ids: list[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + + +class DAG: + """DAG class contains all the DAG nodes""" + + def __init__(self) -> None: + self.nodes: list[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_devices: list[int], + size_bytes: int, + ) -> None: + node = DAGNode( + submodule_node, input_nodes, output_nodes, logical_devices, size_bytes + ) + self.nodes.append(node) + + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module""" + + dag: DAG + module_with_submodules: GraphModule + + +"""Followings are some helper functions for partition manipulation""" + + +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + + +def combine_two_partitions( + partition_0: Partition, partition_1: Partition, partitions: list[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + + +def set_parents_and_children(partitions: list[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition""" + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. + # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + + +def reorganize_partitions(partitions: list[Partition]) -> None: + """Given a list of partitions, reorganize partition id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + + +def get_bfs_level_partition(partitions: list[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: set[Partition] = set() + visited: set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + + +def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]: + """Given a list of partitions,return node to partition mapping""" + node_to_partition: dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + + +def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]: + """Get a mapping from device logical ID to Device object.""" + logical_id_to_device: dict[int, Device] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + return logical_id_to_device + + +def get_device_partition_stats( + partitions: list[Partition], devices: list[Device] +) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]: + """Given a list of partitions and a list of devices, returns: + 1. A mapping from device to partitions on it; + 2. A mapping from device to its remaining memory size; + 3. A list of partitions that do not have a device. + """ + # logical id to device + logical_id_to_device = get_logical_id_to_device(devices) + # Track partitions on device + device_to_partitions: dict[Device, list[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: dict[Device, int] = {} + for d in devices: + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + for logical_id in partition.logical_device_ids: + device = logical_id_to_device[logical_id] + device_to_partitions[device].append(partition) + device_to_left_mem_bytes[device] -= partition.used_mem_bytes + else: + no_device_partitions.append(partition) + + return ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) + + +def get_device_to_partitions_mapping( + partitions: list[Partition], devices: list[Device] +): + """Given a list of partitions and a list of devices, + map each partition into a device. + """ + + def calculate_extra_mem_bytes_needed_for( + partition: Partition, partitions: list[Partition] + ): + all_nodes: set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for( + partition, device_to_partitions[d] + ) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(partitions, devices) + + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = dict( + sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)) + ) + found_device = find_device_for(partition) + if not found_device: + break + return found_device + + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: set[Partition] = {partition} + queue: deque[Partition] = deque([partition]) + while queue: + p = queue.popleft() + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + + def __init__(self) -> None: + self.partitions: list[Partition] = [] + self.node_to_partition: dict[Node, int] = {} + self.devices: list[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig, + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError("No devices") + # Tag the size in bytes to all nodes in the graph_module. + get_size_of_all_nodes(self.graph_module) + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): + raise RuntimeError("No Partition since no operations in the module") + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == "output": + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping, + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition( + total_size_of_graph, logical_device_id=device_with_max_mem.logical_id + ) + elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices): + raise RuntimeError("Devices have no enough memory for the module") + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all( + device.available_mem_bytes == available_mem_bytes + for device in self.devices + ): + raise RuntimeError("All devices must have same memory size!") + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + else: + self.size_based_partition() + + # Saturate host if possible. + if partitioner_config.saturate_host: + self.saturate_host() + + # Partition the graph module based on the partition assignment. + module_with_submodules = self.do_partition() + + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition( + self, total_size_of_graph, logical_device_id: int = 0 + ) -> None: + """Fit the whole fx module into one device""" + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == "output": + # Skip the output node, but there can + # be nodes after the output in certain cases. + continue + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [logical_device_id] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device("", -1, -1) + for d in self.devices: + if ( + d not in occupied_devices + and d.available_mem_bytes >= mem_size_needed + ): + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + "is too large to fit any device") + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: list[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + # Update available mem for the current partition + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if ( + partition_to_left_mem_bytes[partition] + < total_size_of_input_nodes + ): + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of( + node, partition.nodes + ) + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def saturate_host(self) -> None: + """Saturate host by assigning replicates to unused devices with enough memory. + It uses a greedy approach to find a next available set of devices to place all split + partitions: For each used device, it searches for an idle device with minimal memory + size that can hold all the partition located on that device; If the search is successful + for all used devices, it then assigns the new devices' logical ID to the corresponding + partition. + """ + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(self.partitions, self.devices) + + assert len(no_device_partitions) == 0, ( + f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + ) + + # Devices that hold partitions + used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] + # Track replicates of the assigned devices + replicated_device_to_used_device: dict[Device, Device] = {} + + while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( + self.devices + ): + # Success flag for this round + success = True + # Devices that have not been assigned + idle_devices = [ + d + for d in self.devices + if d not in used_devices and d not in replicated_device_to_used_device + ] + # Temporary mapping from replicated device to original device + temp_replicate_mapping = {} + + # Find a new device to replicate all partitions on an used device + for used_device in used_devices: + # Idle devices that have enough memory + available_devices = [ + d + for d in idle_devices + if d.available_mem_bytes + >= used_device.available_mem_bytes + - device_to_left_mem_bytes[used_device] + ] + if len(available_devices) == 0: + success = False + break + new_device = min(available_devices, key=lambda d: d.available_mem_bytes) + idle_devices.remove(new_device) + temp_replicate_mapping[new_device] = used_device + + if not success: + break + replicated_device_to_used_device.update(temp_replicate_mapping) + + # Update logical device IDs assigned to the partitions + for ( + replicate_device, + original_device, + ) in replicated_device_to_used_device.items(): + logical_id = replicate_device.logical_id + for partition in device_to_partitions[original_device]: + partition.logical_device_ids.append(logical_id) + for p in self.partitions: + print(p.logical_device_ids) + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node], + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules.""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == "output": + break + if node.op in {"placeholder", "get_attr"}: + continue + if node.target is operator.__getitem__: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. + if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit("_", 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node( + node, list(input_nodes), output_nodes, device_ids, size_bytes + ) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + + def combine_partitions_based_on_size( + partitions: list[Partition], available_mem_bytes: int + ) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partition used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. + step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = find_partition_to_combine_based_on_size( + sorted_partitions, + available_mem_bytes, + # pyrefly: ignore [bad-argument-type] + partitions, + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: list[Partition], + available_mem_bytes: int, + partitions: list[Partition], + ) -> tuple[bool, list[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boundary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + # pyrefly: ignore [missing-attribute] + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == "call_module": + submodule = self.graph_module + for atom in str(node.target).split("."): + if not hasattr(submodule, atom): + raise RuntimeError( + f"Module {submodule} has no attribute {atom}" + ) + submodule = getattr(submodule, atom) + if "Embedding" in str(submodule): + return True + return False + + # Track embedding partitions and non-embedding partitions separately + embedding_partitions: list[Partition] = [] + non_embedding_partitions: list[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if ( + total_size_of_input_nodes + partition.used_mem_bytes + > available_mem_bytes + ): + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError( + node.target + "is too large to fit into a device" + ) + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = ( + "Need " + + str(len(embedding_partitions)) + + " devices, but only " + + str(len(self.devices)) + + " provided" + ) + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if ( + total_size_of_non_embedding_partitions + partition.used_mem_bytes + > available_mem_bytes + ): + raise RuntimeError( + "partition_" + + str(partition.partition_id) + + "(embedding partition) and non embedding partitions can not fit into one device" + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every beginning, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + + def try_combining_partitions(p0_index, p1_index, partitions) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if ( + (abs(p0.bfs_level - p1.bfs_level) <= 1) + or (p0 in p1.parents) + or p0 in (p1.children) + ): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float("inf") + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping( + partitions, self.devices + ) + if not found_deivce: + return float("inf") + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + return cost + # If two partition can not be combined, the cost is inf + return float("inf") + + def search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its corresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + if len(self.partitions) == 1: + return False + partition_pair: list[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions(i, j, self.partitions[:]) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {"placeholder", "get_attr", "output"}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is estimated. + We first tried using n0 to swap with n2 from the other partition. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes( + n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + cost = float("inf") + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_device: + cost = float("inf") + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition( + node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float("inf") + node_pair: list[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {"placeholder", "get_attr"}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes( + node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ) + if cost < min_cost: + # pyrefly: ignore [bad-assignment] + node_pair = [node, n1] + min_cost = cost + return cost, node_pair # type: ignore[possibly-undefined] + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + # Keep tracking the node pair that shows the better cost + node_pair: list[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: list[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [ + n + for n in self.graph_module.graph.nodes + if n.op not in {"placeholder", "get_attr", "output"} + ] + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes( + node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] + ) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition( + self, node_to_partition_mapping, partition_to_logical_device_mapping + ): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[ + partition_id + ] + else: + partition = partition_id_to_partition_mapping[ + self.node_to_partition[node] + ] + # Add the current node into the partition + partition.add_node(node) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..f494f11593410467623b680a7587e50a614be5a7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py @@ -0,0 +1,354 @@ +# mypy: allow-untyped-defs +import re +from collections.abc import Callable +from typing import Optional, Union + +import torch.fx +from torch.fx.node import map_arg +from torch.fx.passes.split_module import split_module + + +__all__ = [ + "FoldedGraphModule", + "get_unique_attr_name_in_module", + "split_const_subgraphs", +] + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + fx_const_folded_attrs_name: Optional[str] = None, + device_for_folded_attrs: str = "cuda", + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.has_folding_been_run = False + self.fx_const_folded_attrs_name = fx_const_folded_attrs_name + self.device_for_folded_attrs = device_for_folded_attrs + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if ( + self.const_subgraph_module is None + or self.fx_const_folded_attrs_name is None + ): + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. Note that single attr const fold + # subgraphs output a single Tensor while multiple outputs are returned as + # Tuple[Tensor,]. + folded_attrs = self.const_subgraph_module() + + def _create_param(i): + return torch.nn.Parameter( + i.detach().clone() + if not isinstance(i, int) + else torch.Tensor([i]).to(device=self.device_for_folded_attrs), + requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, + ) + + params = ( + torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) + if isinstance(folded_attrs, tuple) + else _create_param(folded_attrs) + ) + setattr(self, self.fx_const_folded_attrs_name, params) + + +def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): + """ + Given `gm` and some graph module which is called with target name `inline_mod_name`, + this helper will inline all of the nodes from that called graph module into `gm`. + """ + # Fetch the inner graph module that we want to inline inside `gm`. + inline_mod = dict(gm.named_modules())[inline_mod_name] + assert isinstance(inline_mod, torch.fx.GraphModule) + call_mod_node_to_replace = None + for node in gm.graph.nodes: + if node.op == "call_module" and node.target == inline_mod_name: + call_mod_node_to_replace = node + break + assert call_mod_node_to_replace is not None + + # Now actually do the swap. Note that we have to keep track of new nodes that are + # copied into `gm` -- we do this via replacement_mapping. + call_mod_args = call_mod_node_to_replace.args + call_mod_kwargs = call_mod_node_to_replace.kwargs + + replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {} + ph_count = 0 + + def replacement_fn(node): + new_node = replacement_mapping[node] + new_node.meta = node.meta.copy() + return new_node + + for inline_node in inline_mod.graph.nodes: + if inline_node.op == "placeholder": + replacement_mapping[inline_node] = ( + call_mod_kwargs[inline_node.name] + if inline_node.name in call_mod_kwargs + else call_mod_args[ph_count] + ) + + ph_count += 1 + continue + + if inline_node.op == "output": + outputs = inline_node.args[0] + output_replacements = map_arg(outputs, replacement_fn) + call_mod_node_to_replace.replace_all_uses_with(output_replacements) + continue + + with gm.graph.inserting_before(call_mod_node_to_replace): + new_node = gm.graph.node_copy(inline_node, replacement_fn) + replacement_mapping[inline_node] = new_node + + # Explicitly remove the module that was just inlined, + # this module may contain impure ops so cannot be dead code eliminated, + # this module is unneeded as it's just inlined back to main graph. + gm.graph.erase_node(call_mod_node_to_replace) + gm.graph.eliminate_dead_code() + + +def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: + """ + Make sure the name is unique (in a module) and can represents an attr. + """ + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + + return name + + +def split_const_subgraphs( + module: Union[torch.nn.Module, torch.fx.GraphModule], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + device_for_folded_attrs: str = "cpu", +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + + import sympy + + if not isinstance(module, torch.fx.GraphModule): + mod_traced = torch.fx.symbolic_trace(module) + else: + mod_traced = module + + def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: + """ + Return True if a GraphModule type subgraph contains any impure op, else False. + """ + assert isinstance(module, torch.fx.GraphModule), ( + "caller should only pass GraphModule to subgraph_has_impure_ops check" + ) + for node in module.graph.nodes: + if node.op == "call_function" and node.is_impure(): + return True + if ( + # pyrefly: ignore [invalid-argument] + node.op == "call_module" + # pyrefly: ignore [not-callable] + and (submodule := module.get_submodule(node.target)) + and isinstance(submodule, torch.fx.GraphModule) + ): + return _subgraph_has_impure_ops(submodule) + return False + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op != "get_attr" and not set(node.all_input_nodes).issubset( + const_nodes + ): + continue + + # If provided skip folding function says to skip, then skip. + if skip_folding_node_fn and skip_folding_node_fn(node): + continue + + # Skip folding side-effectful functions + if node.is_impure(): + continue + + # Skip folding nodes that have symbolic fill_value + if isinstance(node.kwargs.get("fill_value", None), sympy.Expr): + continue + + # Skip folding submodules that have impure ops + if ( + # pyrefly: ignore [invalid-argument] + node.op == "call_module" + # pyrefly: ignore [not-callable] + and (target_mod := mod_traced.get_submodule(node.target)) + and isinstance(target_mod, torch.fx.GraphModule) + and _subgraph_has_impure_ops(target_mod) + ): + continue + + # Must be a constant foldable node at this point. + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + const_mod_name, non_const_mod_name = "submod_0", "submod_1" + # Safely get submod_1 in case there are no non-const nodes + const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None) + + # The module that a call_module node refers to gets copied to submodules during split. + # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to + # attach inlined modules to `split` as it's the owning module now. + for node in non_const_gm.graph.nodes if non_const_gm else []: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside const_gm, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to const_gm must be constants accessible via get_attr. + call_const_gm_args = None + for node in split.graph.nodes: + if node.op == "call_module": + if node.target == const_mod_name: + call_const_gm_args = node.args + break + assert call_const_gm_args is not None + + # Here we do the actual replacement of placeholders to get_attrs. Note that here we + # set the const_gm.graph into a new root_const_gm with split as the root module, + # because we are fetching attributes directly from the root module, instead of + # fetching them from const_gm. Example: The const_gm must have some format like: + # graph(): + # %inp : [num_users=1] = placeholder[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) + # return add + # We replace that with the following, which does not have any placeholders: + # graph(): + # %inp_1 : [num_users=1] = get_attr[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) + # return add + root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + + # The order of placeholders in the const_gm graph should match the order of + # args in the outer module, so we can simply use an index for the + # placeholder mapping + ph_idx = 0 + for node in root_const_gm.graph.nodes: + if node.op == "output": + multiple_outputs = isinstance(node.args[0], tuple) + continue + if node.op != "placeholder": + continue + assert ph_idx < len(call_const_gm_args) + in_node = call_const_gm_args[ph_idx] + ph_idx += 1 + assert in_node.op == "get_attr" + with root_const_gm.graph.inserting_before(node): + new_node = root_const_gm.graph.get_attr(in_node.target) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + root_const_gm.graph.erase_node(node) + assert "multiple_outputs" in locals() + + # Now find the call to const_gm inside split, and replace it with a getattr to the + # folded tensor(s) that result from constant folding. Note that we don't need to + # worry about whether this is one or more tensors because the original graph + # correctly uses getitem to extract individual tensors if there are multiple folded. + fx_const_folded_attrs_name = get_unique_attr_name_in_module( + mod_traced, "_FX_CONST_FOLDED_ATTRS" + ) + setattr( + split, + fx_const_folded_attrs_name, + torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined] + ) + for node in split.graph.nodes: + if node.op == "call_module" and node.target == const_mod_name: + with node.graph.inserting_before(node): + folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) + folded_attrs.meta = node.meta.copy() + node.replace_all_uses_with(folded_attrs) + break + + # Finally, inline the non-constant submod (if it exists) into the split submod. + # This is so that the original caller who may have passed in a graph module will + # get back out a graph module whose graph is traced to the same granularity. + if hasattr(split, non_const_mod_name): + _inline_module(split, non_const_mod_name) + + split.graph.eliminate_dead_code() + + return FoldedGraphModule( + split, + split.graph, + root_const_gm.graph, + fx_const_folded_attrs_name, + device_for_folded_attrs, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py new file mode 100644 index 0000000000000000000000000000000000000000..58a62aee314607320bb5f7eb922192888fa172a5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py @@ -0,0 +1,1011 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections.abc import Callable +from functools import reduce +from typing import TypeVar +from typing_extensions import ParamSpec + +import sympy + +import torch +from torch.fx.experimental.refinement_types import Equality +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} +_REFINEMENT_RULES: dict[Target, Callable] = {} +_RULES: dict[Target, Callable] = {} + +__all__ = [ + "GraphTypeChecker", + "Refine", + "adaptiveavgpool2d_check", + "adaptiveavgpool2d_inference_rule", + "add_inference_rule", + "all_eq", + "bn2d_inference_rule", + "broadcast_types", + "calculate_out_dimension", + "conv2d_inference_rule", + "conv_refinement_rule", + "conv_rule", + "element_wise_eq", + "expand_to_tensor_dim", + "first_two_eq", + "flatten_check", + "flatten_inference_rule", + "flatten_refinement_rule", + "get_attr_inference_rule", + "get_greatest_upper_bound", + "get_parameter", + "linear_check", + "linear_inference_rule", + "linear_refinement_rule", + "maxpool2d_check", + "maxpool2d_inference_rule", + "register_algebraic_expressions_inference_rule", + "register_inference_rule", + "register_refinement_rule", + "relu_inference_rule", + "reshape_inference_rule", + "transpose_inference_rule", +] + + +def expand_to_tensor_dim(t, n): + """ + Expand a type to the desired tensor dimension if possible + Raise an error otherwise. + - t is the given type + - n is a number of dimensions to expand to + """ + if t == Dyn: + dims = [Dyn] * n + return TensorType(tuple(dims)) + elif isinstance(t, TensorType): + if len(t.__args__) != n: + raise TypeError( + f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}" + ) + return t + else: + raise TypeError(f"Cannot match the type {t}") + + +def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with each other and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return t1, t2 + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + s1 = len(t1.__args__) + s2 = len(t2.__args__) + + new_t1 = list(t1.__args__) + new_t2 = list(t2.__args__) + + # We make the types the same length which is the first requirement + # for consistency + if s1 > s2: + for _ in range(s1 - s2): + new_t2.insert(0, 1) + + elif s2 > s1: + for _ in range(s2 - s1): + new_t1.insert(0, 1) + + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor + for i, (x, y) in enumerate(zip(new_t1, new_t2)): + if x == 1: + new_t1[i] = y + elif y == 1: + new_t2[i] = x + + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation + (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) + return (t1, t2) + else: + raise TypeError(f"Cannot broadcast types {t1} and {t2}") + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def register_refinement_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _REFINEMENT_RULES: + raise RuntimeError(f"Refinement rule already registered for {call_target}!") + _REFINEMENT_RULES[call_target] = fn + return fn + + return register + + +def register_algebraic_expressions_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _RULES: + raise RuntimeError(f"Rule already registered for {call_target}!") + _RULES[call_target] = fn + return fn + + return register + + +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + t1 = n.args[0].type + t2 = n.args[1].type + + # handle scalar addition + if t1 is int and isinstance(t2, TensorType): + n.type = t2 + return n.type + + # handle scalar addition + elif t2 is int and isinstance(t1, TensorType): + n.type = t1 + return n.type + + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point + (new_t1, new_t2) = broadcast_types(t1, t2) + + if new_t1 != t1 or new_t2 != t2: + n.meta["broadcast"] = True + n.meta[str(n.args[0])] = new_t1 + n.meta[str(n.args[1])] = new_t2 + + else: + n.meta["broadcast"] = False + + new_t1 = t1 if not n.meta["broadcast"] else new_t1 + new_t2 = t2 if not n.meta["broadcast"] else new_t2 + + # we check for consistency between the new types + if is_consistent(new_t1, new_t2): + # we return the less precise type because + # broadcasting may have happened + # for operands with shape [1,2,Dyn] and [1,2,1] + # we have to assign the node [1,2,Dyn] + if is_more_precise(new_t1, new_t2): + n.type = new_t2 + else: + n.type = new_t1 + return n.type + else: + raise TypeError( + f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}." + f" Types should match " + ) + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ + attr_name = n.args[1] + + if attr_name == "shape": + n.type = Dyn + else: + raise TypeError("Not yet implemented") + + # TODO. We leave it like this till we add a type to represent tensor sizes + return n.type + + +@register_inference_rule(torch.transpose) +def transpose_inference_rule(n: Node): + """ + We check that dimensions for the transpose operations + are within range of the tensor type of the node + """ + if n.target is torch.transpose: + assert isinstance(n.args[0], Node) + t = n.args[0].type + + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + dim1, dim2 = n.args[1], n.args[2] + + if t == Dyn: + n.type = Dyn + return n.type + + elif isinstance(t, TensorType): + if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): + new_type = list(t.__args__) + new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] + final = TensorType(new_type) + n.type = get_greatest_upper_bound(n.type, final) + return n.type + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ + assert isinstance(n.args[0], Node) + t1 = n.args[0].type + + assert isinstance(n.args[1], list) + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) + + # if we do not know the original tensor dimension, + # we return the required dimension + if t1 == Dyn: + n.type = t2_type + return t2_type + + # if any of the dimensions are unknown, + # we check for divisibility + elif isinstance(t1, TensorType): + assert isinstance(t1, TensorType) + a = [e if e != Dyn else 1 for e in t1.__args__] + p1 = reduce(operator.mul, a) + p2 = reduce(operator.mul, t2) + if p1 % p2 == 0 or p2 % p1 == 0: + n.type = t2_type + return t2_type + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + + +@register_inference_rule(BatchNorm2d) +def bn2d_inference_rule(n: Node, module_instance): + """ + Given a BatchNorm2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - t is consistent with t' + - x_2 is consistent with the module's num_features + - x_2' is consistent with the module's num_features + output type: the more precise type of t and t' + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + n.type = expand_to_tensor_dim(n.type, 4) + + # we check the conditions on the incoming argument + # and any existing annotation + # we also check for consistency between both annotations + if ( + is_consistent(arg_type.__args__[1], module_instance.num_features) + and is_consistent(n.type.__args__[1], module_instance.num_features) + and is_consistent(arg_type, n.type) + ): + # we choose the more precise type + # to be the node type + # so if an incoming argument has more type information + # we set this node's type to be the argument type + n.type = get_greatest_upper_bound(arg_type, n.type) + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +def calculate_out_dimension(d_in, module_instance, index): + """ + For calculating h_in and w_out according to the conv2D documentation + """ + padding = ( + (module_instance.padding, module_instance.padding) + if isinstance(module_instance.padding, int) + else module_instance.padding + ) + kernel_size = ( + (module_instance.kernel_size, module_instance.kernel_size) + if isinstance(module_instance.kernel_size, int) + else module_instance.kernel_size + ) + stride = ( + (module_instance.stride, module_instance.stride) + if isinstance(module_instance.stride, int) + else module_instance.stride + ) + dilation = ( + (module_instance.dilation, module_instance.dilation) + if isinstance(module_instance.dilation, int) + else module_instance.dilation + ) + + DIMENSION_TYPES = (int, sympy.Symbol) + + if d_in == Dyn: + return Dyn + + elif isinstance(d_in, DIMENSION_TYPES): + n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1 + + return (n // stride[0]) + 1 + + else: + raise TypeError( + f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}" + ) + + +def get_greatest_upper_bound(type1, type2): + """ + Get the most precise type that's consistent with the given types + """ + if type1 == Dyn: + return type2 + elif type2 == Dyn: + return type1 + elif isinstance(type1, TensorType) and isinstance(type2, TensorType): + if not is_consistent(type1, type2): + raise TypeError(f"Inconsistent types {type1}, {type2}") + gub = [ + t1 if is_more_precise(t1, t2) else t2 + for (t1, t2) in zip(type1.__args__, type2.__args__) + ] + return TensorType(tuple(gub)) + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance): + """ + Given a Conv2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - x_2 is consistent with the module's in_channels + - let o = (x_1, out_channels, H_out, W_out) + then the output is the greatest upper bound of o and the existing node type t'. + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + curr_node_type = expand_to_tensor_dim(n.type, 4) + + if is_consistent(arg_type.__args__[1], module_instance.in_channels): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType( + (arg_type.__args__[0], module_instance.out_channels, h_out, w_out) + ) + gub = get_greatest_upper_bound(new_type, curr_node_type) + n.type = gub + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + n.type = get_greatest_upper_bound(n.args[0].type, n.type) + return n.type + + +def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ + new_type_list = list(typ.__args__) + if len(new_type_list) == 4 or len(new_type_list) == 3: + w_in = new_type_list[-1] + h_in = new_type_list[-2] + + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + + new_type_list[-1] = w_out + new_type_list[-2] = h_out + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Wrong size {typ} for {module_instance}") + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool2d_inference_rule(n: Node, module_instance): + """ + Given a MaxPool2D instance and a node check the following conditions: + - Input size matches size 3 or 4 + - Current node type is consistent with the output type we will calculate + - Input size matches output size and the last two dimensions of the output + are w_out and h_out. The remaining dimensions are the same as the input + - Our final result is the greatest upper bound of the output we calculate + and the current node type. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output = maxpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output, n.type) + return n.type + + +def linear_check(tensor_type, module_instance): + """ + Checks that an input tensor type satisfies the conditions for linear operation + and returns the output type based on in and out features given by module_instance + """ + if len(tensor_type.__args__) >= 2: + if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): + new_type_args = list(tensor_type.__args__) + new_type_args[-1] = module_instance.out_features + return TensorType(tuple(new_type_args)) + else: + raise TypeError( + f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}" + ) + else: + raise TypeError(f"Type {tensor_type} must have rank 2 or more.") + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = linear_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output_type, n.type) + return n.type + + +def adaptiveavgpool2d_check(tensor_type, module_instance): + output_size = module_instance.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + elif isinstance(output_size, tuple): + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = output_size[1] + if output_size[1] is None: + output_size[1] = output_size[0] + + new_type_list = list(tensor_type.__args__) + + if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: + new_type_list[-1] = output_size[1] + new_type_list[-2] = output_size[0] + + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}") + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptiveavgpool2d_inference_rule(n: Node, module_instance): + """ + The input and output sizes should be the same except for the last + two dimensions taken from the input, which represent width and height + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(n.type, output_type) + return n.type + + +def flatten_check(tensor_type, start_dim, end_dim): + l = len(tensor_type.__args__) + + start_dim = l if start_dim == -1 else abs(start_dim) + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: + my_args = list(tensor_type.__args__) + lhs = my_args[0:start_dim] + rhs = my_args[end_dim:] + mid = my_args[start_dim:end_dim] + if Dyn in mid: + mid = [Dyn] + else: + mid = [reduce(operator.mul, my_args[start_dim:end_dim])] + new_type_list = lhs + mid + rhs + return TensorType(tuple(new_type_list)) + else: + raise TypeError( + f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}" + ) + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + output_type = flatten_check(n.args[0].type, start_dim, end_dim) + n.type = get_greatest_upper_bound(output_type, n.type) + + return n.type + + +class GraphTypeChecker: + def __init__(self, env, traced): + self.env = env + self.traced = traced + + def type_check(self): + """ + A gradual type checker for graphs + Effect: every node's field type will be + populated with a type after type-checking is done + """ + graph = self.traced.graph + + # type check every node with gradual type rules + # if any node does not type check return false + for n in graph.nodes: + self.type_check_node(n) + return True + + def type_check_node(self, n: Node): + """ + Type check a given fx node. + Current operations: + - Reshape + - Transpose + - Add + - Relu + - conv2d + - batchnorm2d + - flatten + - maxpool2d + - adaptiveavgpool2d + - linear + """ + if n.type is None: + n.type = Dyn + + if n.op == "placeholder": + return n.type + + elif n.op == "get_attr": + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] + if isinstance(t.data, torch.Tensor): + n.type = TensorType(t.data.shape) + return n.type + + elif n.op == "call_function": + if n.target is getattr: + assert getattr in _INFERENCE_RULES + return _INFERENCE_RULES[n.target](n, self.traced) + + elif n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, module_instance) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") + + +@register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(torch.nn.Linear) +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(BatchNorm2d) +@register_refinement_rule(torch.nn.ReLU) +def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[i], args2[i]) for i in range(len(args1))] + return res + + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + +@register_refinement_rule(torch.add) +@register_refinement_rule(operator.add) +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ + res = [] + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + arg_type1 = n.args[0].type + arg_type2 = n.args[1].type + if ( + isinstance(arg_type1, TensorType) + and isinstance(arg_type2, TensorType) + and isinstance(n.type, TensorType) + ): + args1, args2 = broadcast_types(arg_type1, arg_type2) + # by this point, we know that args1 and args2 are the same size. + a1 = args1.__args__ + a2 = args2.__args__ + a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions + r = [] + for x, y, z in zip(a1, a2, a3): + if x == y: + r.append(Equality(x, z)) + res = r + return res + + +@register_refinement_rule(torch.flatten) +def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ + assert isinstance(n.args[0], Node) + + eq_const = [] + + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): + l = len(n.type.__args__) + arg_type = n.args[0].type + start_dim = l if start_dim == -1 else start_dim + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): + eq_const.append(Equality(t1, t2)) + + for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): + eq_const.append(Equality(t1, t2)) + return eq_const + + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the output in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + + +class Refine: + """ + Symbolic shape inference. + Generates constraints over type variables. + Currently all constraints are equality constraints. + """ + + def __init__(self, traced): + self.constraints = [] + self.traced = traced + self.symbol_iter = itertools.count(start=0, step=1) + + def refine(self): + """ + Generates constraints for + every node in the graph based on + the operation. + """ + graph = self.traced.graph + for n in graph.nodes: + self.refine_node(n) + return True + + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + + def replace_dyn_with_fresh_var(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if typ == Dyn: + new_symbol = Var(next(self.symbol_iter)) + return new_symbol + elif isinstance(typ, TensorType): + new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ + + def refine_node(self, n: Node): + """ + Returns a list of equality constraints for + call_module and call_function nodes. + Models the relation between input and output dimensions + using constraints in case they are both tensors. + All operations used in resnet50 are defined. + """ + if n.type is None: + n.type = Dyn + + n.type = self.replace_dyn_with_fresh_var(n.type) + + if n.op == "call_function": + if n.target in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[n.target](n) + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[type(module_instance)](n) + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + def infer_symbolic_relations(self, n: Node): + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == "call_function": + if n.target in _RULES: + return _RULES[n.target](n) + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3adfba8d412a12012cb3148732e0fab42a7b66 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py @@ -0,0 +1,320 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import warnings +from collections.abc import Callable +from typing import Any, Optional, Union + +import torch +import torch.fx + + +def embedding_override(self, input): + return torch.empty(*input.shape, self.weight.shape[-1], device="meta") + + +def nn_layernorm_override(self, input): + return input + + +def torch_relu_override(x): + return x + + +def torch_nn_relu_override(self, x): + return x + + +def functional_relu_override(x, inplace=False): + assert not inplace, "dont support inplace functional.relu for metatensor analysis" + return x + + +def torch_where_override(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +def torch_abs_override(input, *, out=None): + assert out is None, "Dont support in-place abs for MetaTensor analysis" + return input + + +manual_meta_overrides: dict[Callable, Callable] = { + torch.nn.Embedding: embedding_override, + torch.nn.LayerNorm: nn_layernorm_override, + torch.relu: torch_relu_override, + torch.nn.functional.relu: functional_relu_override, + torch.nn.ReLU: torch_nn_relu_override, + torch.where: torch_where_override, + torch.abs: torch_abs_override, +} + + +def gen_constructor_wrapper(target): + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy("call_function", target, args, kwargs) + else: + return target(*args, **kwargs) + + return wrapper, target + + +class MetaProxy(torch.fx.Proxy): + def install_tensor_meta(self, tensor_meta): + self._tensor_meta = tensor_meta + + def size(self, dim=None): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.size(*[dim] if dim else []) + return self.tracer.create_proxy( + "call_method", "size", (self, dim) if dim else (self,), {} + ) + + def dim(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dim() + return self.tracer.create_proxy("call_method", "dim", (self,), {}) + + @property + def shape(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.shape + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "shape"), {} + ) + + @property + def dtype(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dtype + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "dtype"), {} + ) + + @property + def device(self): + # Hack so we can track when devices are used. During meta-tensor propagation, + # replace these values with a constant 'meta' + return MetaDeviceAttribute(self, "device") + + def __getattr__(self, k): + if k == "_tensor_meta": + return self.__getattribute__(k) + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return MetaAttribute(self, k) + + +class MetaAttribute(MetaProxy): + def __init__(self, root, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): # type: ignore[override] + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + + +class MetaDeviceAttribute(MetaAttribute): + pass + + +def proxys_to_metas(v): + if isinstance(v, MetaDeviceAttribute): + return "meta" + if isinstance(v, torch.fx.Proxy): + assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}" + assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta" + return v._tensor_meta + return v + + +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + rv = super().create_proxy( + kind, + target, + args, + kwargs, + name, + type_expr, + # pyrefly: ignore [bad-argument-type] + proxy_factory_fn, + ) + + if kind == "placeholder" and target in self.meta_args: + rv.install_tensor_meta(self.meta_args[target]) + return rv + + if target in self.orig_fns: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) + kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) + + if kind == "call_function": + meta_target = manual_meta_overrides.get(target, target) + # pyrefly: ignore [not-callable] + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_method": + meta_target = getattr(args_metas[0], target) # type: ignore[index] + meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == "call_module": + assert hasattr(self, "orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in manual_meta_overrides: + meta_out = manual_meta_overrides[mod_type]( + mod, *args_metas, **kwargs_metas + ) # type: ignore[misc, arg-type] + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + assert isinstance(attr_itr, torch.Tensor) + meta_out = attr_itr.to(device="meta") + finally: + self._disable_module_getattr = False + else: + return rv + + # TODO + assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet" + rv.install_tensor_meta(meta_out) + except Exception as e: + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") + + return rv + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + else: + return super().getattr(attr, attr_val, parameter_proxy_cache) + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + return super().call_module(m, forward, args, kwargs) + + def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + idx = 0 + mod_name = mod.__class__.__name__.lower() + path = f"{mod_name}_{idx}" + while hasattr(self.root, path): + path = f"{mod_name}_{idx}" + idx += 1 + + self.root.add_module(path, mod) + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + try: + return super().path_of_module(mod) + except NameError: + if ( + self.allow_insert_stateless_mods + and len(list(mod.parameters())) == 0 + and len(list(mod.buffers())) == 0 + ): + path = self._insert_module_as_submodule(mod) + self.prev_module = path + return path + raise + + def proxy(self, node): + return MetaProxy(node, self) + + def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + assert isinstance(meta_args, dict) + self.meta_args = meta_args + + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + try: + graph = super().trace(root, concrete_args) + graph._tracer_extras = {"meta_args": meta_args} + return graph + finally: + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + + +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + meta_args: Optional[dict[str, torch.Tensor]] = None, + concrete_args: Optional[dict[str, Any]] = None, +) -> torch.fx.GraphModule: + tracer = MetaTracer() + graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + gm = torch.fx.GraphModule(tracer.root, graph, name) + return gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..e46b3a607044a47774db97ec14c0ed40bea3d23d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,637 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjunction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f"And({self.conjucts})" + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return ( + self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + ) + else: + return False + + def __repr__(self): + return f"Or({self.disjuncts})" + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f"Product({self.products})" + + +class T(Constraint): + """ + True + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return "True" + + +class F(Constraint): + """ + False + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return "False" + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string representing the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return ( + self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + ) + else: + return False + + def __repr__(self): + return f"({self.lhs} {self.op} {self.rhs})" + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and ( + isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn + ) + super().__init__(lhs, rhs, op) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the output + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f"can-reshape({self.src}, {self.target})" + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + output: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"IndexSelect({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.dim_replace}, " + f"{self.index})" + ) + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return ( + self.tensor_size == other.tensor_size + and self.dim_replace == other.dim_replace + and self.index == other.index + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class Transpose(Constraint): + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"Transpose({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.index1}, " + f"{self.index2})" + ) + + def __eq__(self, other): + if isinstance(other, Transpose): + return ( + self.tensor_size == other.tensor_size + and self.index1 == other.index1 + and self.index2 == other.index2 + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class GetItem(Constraint): + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" + + def __eq__(self, other): + if isinstance(other, GetItem): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index == other.index + and self.input_var == other.input_var + ) + else: + return False + + +class GetItemTensor(Constraint): + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index_tuple == other.index_tuple + and self.input_var == other.input_var + ) + else: + return False + + +class CalcConv(Constraint): + def __init__( + self, + conv_result, + input_var, + c_out, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output channel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.conv_result} =" + f" calc-conv({self.input_var}," + f" {self.c_out}, {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcConv): + return ( + self.conv_result == other.conv_result + and self.input_var == other.input_var + and self.c_out == other.c_out + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class CalcMaxPool(Constraint): + def __init__( + self, + maxpool_result, + input_var, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.maxpool_result} =" + f" calc-maxpool({self.input_var}," + f" {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return ( + self.maxpool_result == other.maxpool_result + and self.input_var == other.input_var + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return ( + self.res1 == other.res1 + and self.res2 == other.res2 + and self.input1 == other.input1 + and self.input2 == other.input2 + ) + else: + return False + + def __repr__(self): + return ( + f"{self.res1}, {self.res2} =" + f" apply-broadcasting({self.input1}," + f" {self.input2})" + ) + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param flattened: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return ( + self.start == other.start + and self.end == other.end + and self.dims_to_flatten == other.dims_to_flatten + and self.flattened == other.flattened + ) + + else: + return False + + def __repr__(self): + return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f"TV({self.tvar})" + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"DV({self.c})" + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"BV({self.c})" + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, (BVar, Conj, Disj)) + + +def is_dim(d): + return isinstance(d, (DVar, int)) or d == Dyn diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..28e5c7c215e64f0ee61a840f37482c56f988c445 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1565 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections.abc import Callable, Iterable +from typing import TypeVar +from typing_extensions import ParamSpec + +import torch +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + BinConstraintT, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_matching, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_bvar, + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, + gen_tvar, +) +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +__all__ = [ + "ConstraintGenerator", + "adaptive_inference_rule", + "add_layer_norm_constraints", + "add_linear_constraints", + "arange_inference_rule", + "assert_inference_rule", + "batchnorm_inference_rule", + "bmm_inference_rule", + "broadcasting_inference_rule", + "conv2d_inference_rule", + "cumsum_inference_rule", + "embedding_inference_rule", + "embedding_inference_rule_functional", + "eq_inference_rule", + "equality_inference_rule", + "expand_inference_rule", + "flatten_inference_rule", + "full_inference_rule", + "gen_broadcasting_constraints", + "gen_embedding_rules", + "gen_layer_norm_constraints", + "generate_flatten_constraints", + "get_attr_inference_rule", + "getitem_inference_rule", + "gt_inference_rule", + "index_select_inference_rule", + "layer_norm_functional", + "layer_norm_inference_rule", + "linear_constraints", + "linear_inference_rule", + "lt_inference_rule", + "masked_fill_inference_rule", + "maxpool_inference_rule", + "neq_inference_rule", + "range_check", + "register_inference_rule", + "relu_inference_rule", + "reshape_inference_rule", + "size_inference_rule", + "tensor_inference_rule", + "torch_dim_inference_rule", + "torch_linear_inference_rule", + "transpose_inference_rule", + "type_inference_rule", + "view_inference_rule", +] + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == "device": + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq), + ] + ) + + input1_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq + ), + ] + ) + + input2_dyn = Conj( + [ + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq + ), + ] + ) + + consistency_constraints = [ + BinConstraintD(dims_input1[0], dims_input2[0], op_consistency) + ] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj( + [ + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, + TensorType([batch_size, dims_input1[1], dims_input2[2]]), + op_eq, + ), + *consistency_constraints, + DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]), + ] + ) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj( + [ + is_size_1, + Disj( + [ + IndexSelect( + i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + ) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + c3 = Conj( + [ + is_dyn, + Disj( + [ + IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, (Node, int)) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT( + e2, + TensorType( + [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + ), + op_eq, + ) + + constraints, counter = gen_broadcasting_constraints( + e1, e2, symbols, counter, expand + ) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [ + BinConstraintT(expand, TensorType(dims), op_eq), + *nat_constraints, + e2_constraint, + *e2_nat_constraints, + ] + constraints += c + + return constraints, counter + + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +@register_inference_rule(torch.ones) +@register_inference_rule(torch.zeros) +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + output, counter = gen_tvar(counter) + symbols[n] = output + + if isinstance(n.args[0], Node): + input = symbols[n.args[0]] + if isinstance(input, TVar): + return [BinConstraintT(input, output, op_eq)], counter + + # then we have dimension variables + else: + for arg in n.args: + assert isinstance(symbols[arg], DVar) + my_size = [symbols[arg] for arg in n.args] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + + elif isinstance(n.args[0], tuple): + # then the tuple is the size + assert len(n.args[0]) <= 4 + my_size = [symbols[arg] for arg in n.args[0]] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj( + [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)] + ) + + # or input is a tensor and we actually do the replacement + c3 = Disj( + [ + Transpose(i + 1, from_arg, n.args[1], n.args[2], output) + for i in range(MAX_TENSOR_RANK) + ] + ) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [ + BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq), + ], counter + + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implement the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints( + e1, e2, symbols, counter, masked_fill_tensor + ) + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.nn.functional.embedding) +def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + embedding_dim_weights = symbols[n.args[1]] + + # will treat this as a static shape. So we will not use matching. + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + embedding_dim_weights, TensorType(weight_dims), op_eq + ) + embedding_dim = weight_dims[1] + constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) + return [equality_constraint] + constraints, counter + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) + + +def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj( + [ + BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT( + embedding_output, TensorType(new_dims + [embedding_dim]), op_eq + ), + ] + + nat_constraints + ) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.tensor) +def tensor_inference_rule(n: Node, symbols, constraints, counter): + """ + If the tensor is a scalar, we will skip it since we + do not support scalars yet. We will add support in the future + if it's needed. For our examples so far, scalars are not needed. + """ + return [], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + src_var = symbols[n.args[0]] + t2 = [ + symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:] + ] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + # pyrefly: ignore [bad-argument-type] + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + # pyrefly: ignore [bad-argument-type] + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) # type: ignore[arg-type] + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [ + GetItem(i + 1, n.args[1], size_index, input) + for i in range(MAX_TENSOR_RANK) + ] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq), + ] + + [range_check(arg_1, i)] + + nat_constraints + ) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [ + GetItem(i + 1, n.args[1], get_item_output, get_item_arg) + for i in range(MAX_TENSOR_RANK) + ] + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjunction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) # type: ignore[arg-type,assignment] + symbols[n] = get_item_output + + # retrieve arg variables + if n.args[0] in symbols: + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + c2 = [ + GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK) + ] + else: + # TODO: we should figure out why there is a key-error here. + return [], counter + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError("Method not yet implemented") + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + elif isinstance(e1, TVar) and isinstance(e2, int): + # then we made the wrong assumption about the argument being a tensor + # so we should fix the assumption + warnings.warn( + f"Made the wrong assumption for node {n}. Correctness not guaranteed." + ) + + new_e1, counter = gen_dvar(counter) + symbols[n.args[0]] = new_e1 + symbols[n.args[0]] + + gt_constraint = BinConstraintD(new_e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.eq) +def eq_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + eq_tensor, counter = gen_tvar(counter) + symbols[n] = eq_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3] + ) + + dims_inconsistent = Disj( + [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3] + ) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + assert isinstance(n.args[1][3], (Node, int)) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3] + ) + dims_inconsistent4 = Conj( + [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4] + ) + + dims_inconsistent = Disj( + [ + dims_inconsistent1, + dims_inconsistent2, + dims_inconsistent3, + dims_inconsistent4, + ] + ) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError("Method not yet implemented") + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + dim = arg if isinstance(arg, int) else symbols[arg] + res.append(dim) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError("Not yet implemented") + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD( + d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq + ) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj( + [ + BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq), + ] + ) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj( + [ + BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq), + ] + ) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [ + BinConstraintT(arange, TensorType([d1]), op_eq), + Disj([both_dyn, both_numbers]), + ], counter + + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + op_code = None + if n.target is operator.add or n.target is torch.add: + op_code = op_add + elif n.target is operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance( + symbols[n.args[1]], TVar + ): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError("Method not yet implemented") + + elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e1, n.args[1], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e2, n.args[0], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError("Addition not yet implemented") + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints( + start_dim, end_dim, input, flattened, i, counter + ) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.functional.layer_norm) +def layer_norm_functional(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints( + n, module_instance.normalized_shape, symbols, counter + ) + + +def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq), + ] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + return linear_constraints( + n, module_instance.in_features, module_instance.out_features, symbols, counter + ) + + +@register_inference_rule("dim") +def torch_dim_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + my_dim, counter = gen_dvar(counter) + symbols[n] = my_dim + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(my_dim, Dyn, op_eq) + + c1 = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq), + ] + ) + c1.append(c_tensor_i) + + return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter + + +@register_inference_rule(torch._C._nn.linear) +def torch_linear_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + symbols[n.args[1]], TensorType(weight_dims), op_eq + ) + constraints, counter = linear_constraints( + n, weight_dims[1], weight_dims[0], symbols, counter + ) + return [equality_constraint] + constraints, counter + + +def linear_constraints(n: Node, in_features, out_features, symbols, counter): + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj( + [ + BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq), + ] + + add_linear_constraints( + new_dims_rhs_1, new_dims_rhs_2, in_features, out_features + ) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, in_features, out_features): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT( + avg_pool, + TensorType( + [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + ), + op_eq, + ) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv( + my_conv, + input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool( + maxpool, + input_var, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, "graph") else graph + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + # pyrefly: ignore [missing-attribute] + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == "placeholder": + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + + my_type = n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert "example_value" in n.meta + my_type = TensorType(n.meta["example_value"].size()) + else: + my_type = Dyn + + c1 = BinConstraintT(my_type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == "call_function": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)]( + n, module_instance, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "call_method": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "get_attr": + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = list(t.shape) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == "output": + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..267100c8545c8b2310299337ecf64211f633f6ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,14 @@ +op_add = "+" +op_sub = "-" +op_mul = "*" +op_div = "/" +op_eq = "=" +op_neq = "!=" +op_imp = "=>" +op_matching = "\u22b3" # (contains) +op_consistency = "~" +op_precision = "\u2291" # (square image of or equal to) +op_leq = "\u2264" # less-than or equal to +op_lt = "<" +op_gt = ">" +op_mod = "%" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000000000000000000000000000000..939f4865ab7d982289303093db2024eda6603521 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,30 @@ +try: + import z3 # type: ignore[import] + + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort("Dyn") + dyn_type = z3.Const("dyn", dyn) + + # dimension + dim = z3.Datatype("dim") + dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype("TensorType") + tensor_type.declare("Dyn", ("dyn", dyn)) + tensor_type.declare("tensor1", ("0", dim)) + tensor_type.declare("tensor2", ("0", dim), ("1", dim)) + tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim)) + tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..e2dd3c962bbe4d274284d8db26bac70a1a170bed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py @@ -0,0 +1,164 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.fx +import torch.fx as fx +from torch.fx import Proxy, Transformer +from torch.fx.node import Argument, map_aggregate, Node, Target +from torch.fx.operator_schemas import ( + create_type_hint, + normalize_function, + normalize_module, +) + +from .schema_type_annotation import AnnotateTypesWithSchema + + +class NormalizeArgs(Transformer): + """ + Normalize arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and rewritten to exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. Also populates default + values. Does not support positional-only parameters or varargs + parameters (*args, **kwargs). + + If the nodes have 'type' metadata, it will use it to disambiguate + overloads. Otherwise, it will throw an error. + + Example usage: + m = torchvision.models.resnet18() + traced = torch.fx.symbolic_trace(m) + traced = NormalizeArgs(traced).transform() + """ + + def __init__( + self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True + ): + super().__init__(module) + self.node_map: dict[Proxy, Node] = {} + self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs + + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + def get_type(arg): + if isinstance(arg, fx.Node): + return n.meta.get("type") + return type(arg) + + arg_types = map_aggregate(n.args, get_type) + assert isinstance(arg_types, tuple) + arg_types = tuple(create_type_hint(i) for i in arg_types) + kwarg_types = {k: get_type(v) for k, v in kwargs.items()} + if n.op == "call_function": + out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) + else: + out = super().run_node(n) + if n.op != "output": + self.node_map[out] = n + out.node.meta = n.meta + out.node.type = n.type + return out + + def call_function( + self, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + arg_types: Optional[tuple[Any, ...]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + ): + assert callable(target) + new_args_and_kwargs = normalize_function( + target, + args, # type: ignore[arg-type] + kwargs, + arg_types, # type: ignore[arg-type] + kwarg_types, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return self.tracer.create_proxy( + "call_function", target, new_args, new_kwargs + ) + else: + return super().call_function(target, args, kwargs) + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + assert isinstance(target, str) + new_args_and_kwargs = normalize_module( + self.module, + target, + args, # type: ignore[arg-type] + kwargs, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return super().call_module(target, new_args, new_kwargs) + else: + return super().call_module(target, args, kwargs) + + +class NormalizeOperators(AnnotateTypesWithSchema): + """ + Normalize callsites that are different ways of "spelling" the same + invocation into a single, canonical call. Currently supports: + + 1. Normalize operators (e.g. operator.add) to the `torch` ops they + ultimately invoke (e.g. torch.add) when it is possible to statically + reason that + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = NormalizeOperators(traced).transform() + """ + + binary_magic_method_remap: dict[ + Callable[[Any, Any], Any], Callable[[Any, Any], Any] + ] = { + torch.add: operator.add, + torch.mul: operator.mul, + torch.sub: operator.sub, + torch.div: operator.truediv, + torch.floor_divide: operator.floordiv, + torch.remainder: operator.mod, + torch.eq: operator.eq, + torch.ne: operator.ne, + torch.lt: operator.lt, + torch.le: operator.le, + torch.gt: operator.gt, + torch.ge: operator.ge, + } + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + # Normalize operators according to the magic methods implemented on tensors here: + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + + assert callable(target) + + if target in self.binary_magic_method_remap: + if len(args) != 2: + return super().call_function(target, args, kwargs) + lhs, rhs = args + + return super().call_function( + target=self.binary_magic_method_remap[target], + args=(lhs, rhs), + kwargs={}, + ) + + return super().call_function(target, args, kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..219e6f66c7bf52d8f4bf6384b871dee4a9a494d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import copy +import logging +import operator +import time +from collections import defaultdict +from collections.abc import Iterable +from enum import Enum +from typing import Any, cast, Optional + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.mkldnn as th_mkldnn +from torch.fx.node import Argument, Target +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval + + +__all__ = [ + "matches_module_pattern", + "replace_node_module", + "fuse", + "remove_dropout", + "extract_subgraph", + "modules_to_mkldnn", + "reset_modules", + "MklSubgraph", + "gen_mkl_autotuner", + "use_mkl_length", + "UnionFind", + "optimize_for_inference", +] + + +def _parent_name(target: str) -> tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + + +# Works for length 2 patterns with 2 modules +def matches_module_pattern( + pattern: Iterable[type], node: fx.Node, modules: dict[str, Any] +): + if len(node.args) == 0: + return False + nodes: tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != "call_module": + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module( + node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module +): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + modules[node.target] = new_module + setattr(modules[parent_name], name, new_module) + + +def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: + """ + Fuses convolution/BN and linear/BN layers for inference purposes. + Will deepcopy your model by default, but can modify the model inplace as well. + """ + patterns = [ + (nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d), + (nn.Linear, nn.BatchNorm1d), + ] + if not inplace: + model = copy.deepcopy(model) + if not no_trace or not isinstance(model, torch.fx.GraphModule): + fx_model = fx.symbolic_trace(model) + else: + fx_model = model + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: + # Output of conv/linear is used by other nodes + continue + first_layer = modules[node.args[0].target] + bn = modules[node.target] + if not bn.track_running_stats: + continue + if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: + fused_layer = fuse_conv_bn_eval(first_layer, bn) + else: # nn.Linear + fused_layer = fuse_linear_bn_eval(first_layer, bn) + replace_node_module(node.args[0], modules, fused_layer) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) + + +def remove_dropout(model: nn.Module) -> nn.Module: + """ + Removes all dropout layers from the module. + """ + fx_model = fx.symbolic_trace(model) + + class DropoutRemover(torch.fx.Transformer): + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if isinstance(self.submodules[target], nn.Dropout): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return DropoutRemover(fx_model).transform() + + +def extract_subgraph( + orig_module: nn.Module, + nodes: list[fx.Node], + inputs: list[fx.Node], + outputs: list[fx.Node], +): + """ + Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. + """ + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + for input in inputs: + new_node = new_graph.placeholder(input.name) + env[input] = new_node + for node in nodes: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + new_graph.output([env[output] for output in outputs]) + new_graph.lint() + return fx.GraphModule(orig_module, new_graph) + + +mkldnn_supported = [ + nn.Conv2d, + nn.Linear, + nn.BatchNorm2d, + nn.ReLU, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, + torch.relu, + torch.transpose, + torch.sigmoid, + F.relu, + F.avg_pool2d, + F.adaptive_avg_pool2d, +] +# These are operators that may not be convertible into MKLDNN ops (e.g. the +# args are scalar values). Thus, we only include them in the subgraph if their +# arguments are already in MKLDNN. +# TODO: Determine whether this can be removed after type inference. +mkldnn_supported_unknown = [operator.add, operator.mul] +mkldnn_map = { + nn.Conv2d: th_mkldnn.MkldnnConv2d, + nn.Linear: th_mkldnn.MkldnnLinear, + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a), +} + + +def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): + """ + For each node, if it's a module that can be preconverted into MKLDNN, + then we do so and create a mapping to allow us to convert from the MKLDNN + version of the module to the original. + """ + old_modules: dict[nn.Module, nn.Module] = {} + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in mkldnn_map: + # pyrefly: ignore [index-error] + new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) + assert isinstance(new_module, nn.Module) + old_modules[new_module] = copy.deepcopy(cur_module) + replace_node_module(node, modules, new_module) + return old_modules + + +def reset_modules( + nodes: list[fx.Node], + modules: dict[str, nn.Module], + old_modules: dict[nn.Module, nn.Module], +): + """ + Maps each module that's been changed with `modules_to_mkldnn` back to its + original. + """ + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if cur_module in old_modules: + replace_node_module(node, modules, old_modules[cur_module]) + + +class MklSubgraph: + def __init__(self, fx_graph: fx.Graph): + self.fx_graph = fx_graph + self.nodes: list[fx.Node] = [] + self.start_nodes: list[fx.Node] = [] + self.end_nodes: list[fx.Node] = [] + + +def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): + """ + This generates a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by running it with the example_inputs. + + Example usage: + heuristic = gen_mkl_autotuner(example_inputs, iters=10) + fast_model = optimization.optimize_for_inference(model, heuristic) + """ + fx_model = None + old_modules = None + + def use_mkl_heuristic(graph: MklSubgraph) -> bool: + nonlocal fx_model, old_modules + input_nodes = graph.start_nodes + if fx_model is None: + fx_model = graph.fx_graph.owning_module + old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] + ShapeProp(fx_model).propagate(example_inputs) + sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] + output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes]) + submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) + + def benchmark(f): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + f() + return time.time() - begin + + mkl_time = benchmark( + lambda: [ + i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs]) + ] + ) + + reset_modules( + submodule.graph.nodes, + dict(submodule.named_modules()), + # pyrefly: ignore [bad-argument-type] + old_modules, + ) + no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) + return mkl_time < no_mkl_time + + return use_mkl_heuristic + + +def use_mkl_length(graph: MklSubgraph) -> bool: + """ + This is a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by checking if there + are more than 2 nodes in it + """ + return len(graph.nodes) > 2 + + +class UnionFind: + def __init__(self, n): + self.parent: list[Optional[int]] = [None] * n + self.size: list[int] = [0] * n + + def make_set(self, v: int): + self.parent[v] = v + self.size[v] = 1 + + def find(self, v: int) -> int: + par = self.parent[v] + if v == par: + return v + assert par is not None + self.parent[v] = self.find(par) + return cast(int, self.parent[v]) + + def join(self, a: int, b: int): + a, b = self.find(a), self.find(b) + if a == b: + return a + if self.size[a] < self.size[b]: + a, b = b, a + self.parent[b] = a + self.size[a] += self.size[b] + + +def optimize_for_inference( + model: torch.nn.Module, + pass_config: Optional[dict[str, Any]] = None, + tracer: type[fx.Tracer] = fx.Tracer, +) -> torch.nn.Module: + """ + Performs a set of optimization passes to optimize a model for the + purposes of inference. Specifically, the passes that are run are: + 1. Conv/BN fusion + 2. Dropout removal + 3. MKL layout optimizations + + The third optimization takes a function `use_mkl_heuristic` that's used + to determine whether a subgraph should be explicitly run in MKL layout. + + Note: As FX does not currently handle aliasing, this pass currently + assumes nothing aliases. If that isn't true, use at your own risk. + """ + default_pass_config = { + "conv_bn_fuse": True, + "remove_dropout": True, + "mkldnn_layout_optimize": {"heuristic": use_mkl_length}, + } + if pass_config is None: + pass_config = {} + default_pass_config.update(pass_config) + + if default_pass_config["conv_bn_fuse"]: + model = fuse(model) + if default_pass_config["remove_dropout"]: + model = remove_dropout(model) + if default_pass_config["mkldnn_layout_optimize"] is False: + return model + if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): + raise RuntimeError("mkldnn_layout_optimize config is not a dict") + if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: + raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") + use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] + + cur_tracer = tracer() + fx_graph = cur_tracer.trace(copy.deepcopy(model)) + fx.GraphModule(cur_tracer.root, fx_graph) + modules: dict[str, nn.Module] = dict(model.named_modules()) + + class MklSupport(Enum): + NO = 1 + YES = 2 + UNKNOWN = 3 + + # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. + # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. + # However, if it's in `mkldnn_supported_unknown`, then we only treat it as + # a MKLDNN node if its inputs are MKLDNN nodes. + for node in list(fx_graph.nodes): + supports_mkldnn = MklSupport.NO + if node.op == "call_module": + cur_module = modules[node.target] + if type(cur_module) in mkldnn_supported: + supports_mkldnn = MklSupport.YES + sample_parameter = next(cur_module.parameters(), None) + if sample_parameter is not None: + assert sample_parameter.dtype == torch.float, ( + "this pass is only for torch.float modules" + ) + assert sample_parameter.device == torch.device("cpu"), ( + "this pass is only for CPU modules" + ) + elif node.op == "call_function": + if node.target in mkldnn_supported: + supports_mkldnn = MklSupport.YES + elif node.target in mkldnn_supported_unknown: + supports_mkldnn = MklSupport.UNKNOWN + + if supports_mkldnn != MklSupport.NO: + if supports_mkldnn == MklSupport.UNKNOWN: + if not any(arg.target == "to_dense" for arg in node.args): + continue + with fx_graph.inserting_before(node): + mkldnn_args = fx.map_arg( + node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) + ) + + node.args = cast(tuple[fx.node.Argument], mkldnn_args) + + with fx_graph.inserting_after(node): + dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) + node.replace_all_uses_with(dense_x) + dense_x.args = (node,) + + # Does pre-conversion of all modules into MKLDNN (when possible) + old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) + fx_graph.old_modules = old_modules # type: ignore[attr-defined] + + # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b + for node in fx_graph.nodes: + if node.op == "call_method" and node.target == "to_dense": + prv_node = node.args[0] + users = list(node.users) + for user in users: + if user.op == "call_method" and user.target == "to_mkldnn": + user.replace_all_uses_with(prv_node) + fx_graph.erase_node(user) + if len(node.users) == 0: + fx_graph.erase_node(node) + + num_nodes = len(fx_graph.nodes) + uf = UnionFind(num_nodes) + + def get_color(n): + if hasattr(n, "color"): # Current node is part of a MKL subgraph + return uf.find(n.color) + if hasattr(n, "start_color"): # Current node is input to MKL subgraph + return uf.find(n.start_color) + return None + + # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists + # of input nodes (which are only `to_mkldnn` calls), output nodes + # (`to_dense` calls), and intermediate nodes, which are run entirely on + # MKLDNN layout tensors. + # + # Specifically, this code does a flood fill on a directed acyclic graph + # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). + # If every node only had one input, this would be sufficient. However, in + # the case that a node has multiple inputs coming from different start + # nodes (i.e. colors), we need to join these 2 colors into 1. That's done + # using a Disjoint Set Union. + for cur_idx, node in enumerate(fx_graph.nodes): + if node.op == "call_method" and node.target == "to_mkldnn": + node.start_color = cur_idx + uf.make_set(cur_idx) + elif node.op == "call_method" and node.target == "to_dense": + assert get_color(node.args[0]) is not None + node.end_color = get_color(node.args[0]) + else: + cur_colors = [ + get_color(i) + for i in node.all_input_nodes + if isinstance(i, fx.Node) + if get_color(i) is not None + ] + + if len(cur_colors) == 0: + continue + assert not any(i is None for i in cur_colors) + cur_colors = sorted(cur_colors) + node.color = cur_colors[0] + for other_color in cur_colors[1:]: + uf.join(cur_colors[0], other_color) + + mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + for node in fx_graph.nodes: + if hasattr(node, "color"): + mkldnn_graphs[uf.find(node.color)].nodes.append(node) + if hasattr(node, "start_color"): + mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) + if hasattr(node, "end_color"): + mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) + + # Now that we have all the subgraphs, we need to decide which MKLDNN + # subgraphs we actually want to keep in MKLDNN. + for graph in mkldnn_graphs.values(): + if not use_mkl_heuristic(graph): + for node in graph.start_nodes + graph.end_nodes: + prv = node.args[0] + node.replace_all_uses_with(prv) # type: ignore[arg-type] + fx_graph.erase_node(node) + reset_modules(graph.nodes, modules, old_modules) + + mkldnn_conversions = 0 + for node in fx_graph.nodes: + if node.target == "to_mkldnn" or node.target == "to_dense": + mkldnn_conversions += 1 + + logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) + fx_graph.lint() + result = fx.GraphModule(model, fx_graph) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3658dd1a9ce96aff26adbc5f47818e9e57e13d35 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import NamedTuple + +from torch.fx.node import map_arg, Node + + +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + + def __init__(self, partition_id: int) -> None: + self.nodes: set[Node] = set() + self.partition_id = partition_id + self.parents: set[Partition] = set() + self.children: set[Partition] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: list[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {"placeholder", "get_attr"}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all( + n not in self.nodes for n in input_node.users + ) and input_node.op in {"placeholder", "get_attr"}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + + +class PartitionerConfig(NamedTuple): + devices: list[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0.0 + node_to_latency_mapping: dict[Node, NodeLatency] = {} + node_to_partition_mapping: dict[Node, int] = {} + partition_to_logical_device_mapping: dict[int, list[int]] = {} + # Saturate host by replicating partitions to the remaining idle devices. + saturate_host: bool = False + + +def get_extra_size_of(node: Node, nodes: set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError("node has no size_bytes attr") + # Don't forget the op node itself + size_bytes = getattr(node, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError("node has no size_bytes attr") + return total_size_of_input_nodes + + +def get_latency_of_one_partition( + partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> list[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: list[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {"placeholder", "get_attr"}: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any( + n in partition.nodes and n.op not in {"placeholder", "get_attr"} + for n in input_nodes + ): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + max( + node_latency.computer_latency_sec, node_latency.mem_latency_sec + ) + # Update the mem latency of this path + mem_latency_sec = ( + partition_latency.mem_latency_sec + node_latency.mem_latency_sec + ) + # Update the compute latency of this path + computer_latency_sec = ( + partition_latency.computer_latency_sec + node_latency.computer_latency_sec + ) + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper( + n, + PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ), + ) + if ( + new_partition_latency.overall_latency_sec + > max_latency.overall_latency_sec + ): + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ) + + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper( + node, + PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ), + ) + if ( + partition_latency.overall_latency_sec + > critical_path_latency.overall_latency_sec + ): + critical_path_latency = partition_latency + return critical_path_latency + + +def get_partition_to_latency_mapping( + partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency] +) -> dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition( + partition, node_to_latency_mapping + ) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + + +def get_comm_latency_between( + parent_partition: Partition, + child_partition: Partition, + transfer_rate_bytes_per_sec: float, +): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if ( + parent_partition.logical_device_ids != [] + and child_partition.logical_device_ids != [] + and parent_partition.logical_device_ids == child_partition.logical_device_ids + ): + return 0.0 + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + + +def get_latency_of_partitioned_graph( + partitions: list[Partition], + partition_to_latency_mapping: dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float, +): + """Given all partitions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions""" + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec + + if partition.children: + max_latency_sec = 0.0 + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between( + partition, child, transfer_rate_bytes_per_sec + ) + new_latency_sec = dfs_helper( + child, latency_so_far_sec + comm_latency_sec + ) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: list[Partition]) -> list[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + # If a partition has no parents, then it is a top partition + top_partitions = [ + partition for partition in partitions if len(partition.parents) == 0 + ] + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0.0 + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.0) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..f763ad2ee2cfc1e3bd500f1da9877144aca1a3b2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py @@ -0,0 +1,2817 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import functools +import inspect +import logging +import operator +import threading +import typing +import typing_extensions +import weakref +from collections import defaultdict, OrderedDict +from collections.abc import Callable, Generator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import ( + Any, + Concatenate, + Optional, + overload, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack +from weakref import WeakKeyDictionary + +import torch +import torch._ops +import torch.fx as fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import SymBool, SymInt, Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_type +from torch._logging import trace_structured +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_impls import fast_detach +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + is_fake, + unset_fake_temporarily, +) +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx import GraphModule, Proxy, Tracer +from torch.fx.graph_module import _assign_attr +from torch.fx.node import ( + _side_effectful_need_to_be_preserved_pre_dispatch, + Argument, + Target, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.nn import Module +from torch.overrides import TorchFunctionMode +from torch.utils._python_dispatch import ( + _disable_infra_mode, + _push_mode, + _unset_infra_mode, + autograd_would_have_decomposed, + TorchDispatchMode, +) +from torch.utils._stats import count +from torch.utils._thunk import Thunk +from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary + +from ._backward_state import BackwardState +from .sym_node import SymNode + + +if TYPE_CHECKING: + import types + from collections.abc import MutableMapping + + import sympy + + from torch._ops import OpOverload + from torch.fx._symbolic_trace import PHBase + from torch.types import BoolLikeType, FloatLikeType, IntLikeType + +__all__ = [ + "PythonKeyTracer", + "dispatch_trace", + "make_fx", + "DecompositionInterpreter", + "selective_decompose", + "py_sym_types", + "get_innermost_proxy_mode", + "get_proxy_mode", + "handle_sym_dispatch", + "maybe_enable_thunkify", + "maybe_disable_thunkify", +] + +_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] + +_AnyScriptObject = (torch.ScriptObject, FakeScriptObject) +_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject] + +aten = torch.ops.aten +prim = torch.ops.prim + +log = logging.getLogger(__name__) +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + +CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} + +CONSTANT_NUMEL_LIMIT = 1 + +T = TypeVar("T") +U = TypeVar("U") +_P = ParamSpec("_P") +R = TypeVar("R") +_Ts = TypeVarTuple("_Ts") + +null_ctx_type = type(nullcontext) +# We currently convert all SymInt to proxies before we use them. +# This could plausibly be handled at the Dynamo level. +pytree.register_pytree_node( + torch.Size, + lambda xs: (list(xs), None), + lambda xs, _: tuple(xs), + # pyrefly: ignore [bad-argument-type] + flatten_with_keys_fn=lambda xs: ( + [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], + None, + ), + serialized_type_name="torch.Size", +) +# Ideally unflattening should not lose info, but we unflatten +# torch.Size to tuple (see above). This is necessary because the +# torch.Size constructor only accepts ints whereas our infra often +# transforms them to non-ints, e.g. symint proxies. Anyway, losing +# such info can cause pytree mapping or spec matching to fail, so +# work around this problem using the following dict as needed. +_pytree_subclasses_that_lose_info = {torch.Size: tuple} + + +def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + + +@contextmanager +def decompose( + decomposition_table: Optional[Mapping[OpOverload, Callable]], +) -> Generator[Mapping[OpOverload, Callable], None, None]: + global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + + +# ensure we cannot collide with other properties +proxy_slot = object() + + +class _NoDefault: + pass + + +no_default = _NoDefault() + +from torch.types import py_sym_types, PySymType + + +class _HasMeta(Protocol): + meta: dict[str, PySymType] + + +def is_sym_node(node: _HasMeta) -> bool: + assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta" + return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) + + +@overload # type: ignore[no-overload-impl] +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... + + +@overload +def set_proxy_slot( + obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy +) -> None: ... + + +@overload +def set_proxy_slot( + obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType +) -> None: ... + + +class _DisableUpdateTensorTracker(threading.local): + value: bool = False + + +_disable_update_tensor_tracker_tls = _DisableUpdateTensorTracker() + + +_FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT: dict[int, torch.fx.Node] = {} + + +def _is_proxy_tensor_update_tensor_tracker_disabled() -> bool: + """ + Returns current state of disabling update tensor tracker. + """ + return _disable_update_tensor_tracker_tls.value + + +@contextmanager +def _proxy_tensor_disable_update_tensor_tracker() -> Generator[None, None, None]: + """ + NOTE "Do not clobber inplace ops" + By default tensor_tracker is updated every time. + This leads to chaining every operation by the FakeTensor. + For example for mutable ops if we have several consecutive mutable operations: + + def f(x, y, z): + x.copy_(y) + x.copy_(z) + return x + + Default graph result: + def f_graph(x, y, z) + x_1 = x.copy_(y) + x_2 = x_1.copy_(z) + return x_2 + + This chaining simplifies the fx passes and helps to prevent the reordering. + But in some cases, we want those nodes to be disconnected. + E.g. in case of splitting joint graph into forward and backward. + If first inplace op happened in forward, second in backward, + we want them after split to be properly placed. + + Enabling this context manager for copy_ will result in: + def f_graph_2(x, y, z): + x_1 = x.copy_(y) + x_2 = x.copy_(z) + return x + + Results of copy_ x1 and x2 will have empty users in the graph. + The reason why this behavior is not enabled for all inplace ops is that + some fx passes (e.g. fx quantization) rely on chaining inplace ops like add_ + in their fusions passes. + We could revisit enabling this logic for all inplace ops in future. + """ + orig_value = _disable_update_tensor_tracker_tls.value + _disable_update_tensor_tracker_tls.value = True + try: + yield + finally: + _disable_update_tensor_tracker_tls.value = orig_value + + +def set_proxy_slot( # type: ignore[no-redef] + obj: Union[PySymType, _AnyScriptObjectType, Tensor], + tracer: _ProxyTracer, + proxy: object, +) -> None: + log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) + if isinstance(obj, Tensor): + # We DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + assert isinstance(proxy, _ProxyTensor) + # see NOTE [Do not clobber inplace ops] + if not _is_proxy_tensor_update_tensor_tracker_disabled(): + tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, (_AnyScriptObject)): + # We DO want to clobber proxies, with a similar rationale as for tensors. + assert isinstance(proxy, Proxy) + tracer.script_object_tracker[obj] = proxy + else: + # NB: Never clobber pre-existing proxy. Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + assert isinstance(obj, py_sym_types), type(obj) + if obj not in tracer.symnode_tracker: + proxy = typing.cast(_PySymProxyType, proxy) + tracer.symnode_tracker[obj] = proxy + + # WAR: python test/dynamo/test_subclasses.py + # TestNestedTensor.test_basic_autograd + # + # AOTAutograd doesn't pass the "outer sizes" as an actual argument + # to make_fx, but it is made use of internally in AOTAutograd's + # call to tensor unflatten. Because the outer sizes isn't passed + # as an argument, it is therefore untracked. However, it turns + # out you luck out, because *Dynamo* will manually add the outer + # sizes as an argument so you can fix up the proxy'ness. + # + # This is probably fixed in + # https://github.com/pytorch/pytorch/pull/125941/ + import sympy + + if isinstance(obj.node.expr, sympy.Symbol): + tracer.sympy_expr_tracker[obj.node.expr] = _SympyExprTrackerValue( + proxy, obj + ) + + +def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: + assert isinstance(obj, (Tensor, SymNode)), type(obj) + # pyrefly: ignore [no-matching-overload] + return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) + + +_PySymProxyType = Thunk[Proxy] + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, +) -> _ProxyTensor: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, +) -> Union[_ProxyTensor, U]: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_ProxyTensor], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, +) -> Proxy: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, +) -> Union[Proxy, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[Proxy], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, +) -> _PySymProxyType: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: T, +) -> Union[T, _PySymProxyType]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_PySymProxyType], R], +) -> Union[R, U]: ... + + +# the default argument is what to return if the slot is not set. +# the transform argument is handy if you need to extract a subfield from +# the successfully looked up result (but NOT the default.) +def get_proxy_slot( + obj: Union[Tensor, _AnyScriptObjectType, PySymType], + tracer: _ProxyTracer, + default: object = no_default, + transform: Callable = lambda x: x, +) -> object: + tracker: Any + if isinstance(obj, Tensor): + tracker = tracer.tensor_tracker + elif isinstance(obj, _AnyScriptObject): + tracker = tracer.script_object_tracker + else: + assert isinstance(obj, py_sym_types), type(obj) + tracker = tracer.symnode_tracker + + # pyrefly: ignore [index-error] + # pyrefly: ignore [no-matching-overload, bad-argument-type] + value = tracker.get(obj) + + if value is None and isinstance(obj, py_sym_types): + if obj.node.is_symbolic(): + # Last ditch - we found a SymInt (SymBool, etc) we don't know + # about. + if (tmp := tracer.sympy_expr_tracker.get(obj.node.expr)) is not None: + value = tmp.proxy + + else: + # Attempt to build it from first principles. + _build_proxy_for_sym_expr(tracer, obj.node.expr, obj) + # pyrefly: ignore [no-matching-overload] + value = tracker.get(obj) + + if value is None: + # We don't know this value - return the default. + if isinstance(default, _NoDefault): + raise RuntimeError( + f"{obj} ({type(obj)}, {id(obj)})is not tracked with proxy for {tracer}" + ) + return default + + res = transform(value) + return res + + +@functools.cache +def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]: + """ + Returns a dict converting sympy functions to python operators + (i.e. `sympy.Mul` -> `operator.mul`) + """ + import torch.utils._sympy.interp + + handlers = {} + for k, v in torch.utils._sympy.interp.handlers().items(): + op = getattr(operator, v, None) + if op is not None: + handlers[k] = op + return handlers + + +def _build_proxy_for_sym_expr( + tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None +) -> IntLikeType | FloatLikeType | BoolLikeType | None: + """ + Decompose `expr` and look for the pieces as inputs. If `out` is provided + then that will be the resulting SymNode (and `out.expr` must be the same as + `expr`). + + This function is used when the ProxyTorchDispatchMode sees a SymNode + that it hasn't seen before to try to associate it with traced inputs. + + How can this happen? + + First thing to remember is that although sympy.Exprs are interned (so + `sympy.Expr("s3*s4")` will always have the same `id` and will always compare + equal) SymNode does not (so doing `SymNode("s3")*SymNode("s4")` twice in a + row will give two unique SymNodes). + + - On way for this to happen is if we turn off tracing to compute an + intermediate value and then USE that value with tracing turned on - for + example if we turn off tracing to do some FakeTensor propagation to + compute a size (dtensor does this) but then turn tracing back on and use + that computed size. + + - Another way is if we compute a size in one graph and stash it somewhere + hidden (such as in some meta-data) and later use it in a different graph + (dtensor does this too). Since the size was computed in the first graph + and it's not an official input to the second graph it's not tracked + properly. This is often going to show up as it usually works in fullgraph + but a graph break causes a failure. + + To handle this we decompose the sympy.Expr and look for the pieces as + inputs. But there are problems with this approach: + + - We lose operation provanance: We end up figuring out where to get the + inputs - but those may not actually be correct. If we have "s1" coming in + from both tensor1 and tensor2 and we pick the wrong one we could end up + keeping a tensor alive longer than intended. + + - There's no guarantee that those values are inputs to the graph: If we have + "s1*s2" computed in a graph #1 and used in graph #2 there's no guarantee + that the input that holds "s1" is actually an input on graph #2. + + - The decomposition isn't guaranteed to be the same: Sympy can "simplify" + expressions so it's possible that our inputs are "s1*s2" and "s3" but we + decompose it into "s1" and "s2*s3" - which wouldn't be found. + + Other ways we could handle this: + + - Don't: Just require that all inputs are tracked properly. This is the + "correct" solution but harder because you need to track down each + potential problem one by one and fix them. And when it fails it's a lot of + work to figure out both why it's failing and the right way to fix it. This + is complicated by the fact that a stashed value could be incorrect but + work fine until we happen to get an graph break in the wrong place - so it + may be a while before the bug is found. (Maybe we need a "dynamo abuse + mode" where we run tests with as many graph breaks inserted as possible?) + + - Track SymNode ops separately from proxy tracing: Right now SymNode + operations are tracked as part of the proxy tracing - so when we disable + proxy tracing we also disable SymNode tracing. But we don't have to do + that - we could instead always have SymNodes track where they came from + and just use that when needed. This solves the problem of tracing being + temporarily turned off but doesn't help if an input isn't present after a + graph break. + + - Better decomposition: Right now the decomposition is pretty simple. We do + have a sat-solver available to us so we could theoretically do a better + job figuring out a "correct" decomposition. But that still relies on + having the inputs available at all - which isn't a guarantee. + """ + + if (value := tracer.sympy_expr_tracker.get(expr)) is not None: + assert not out + return value.value + + if isinstance(expr, (int, float, bool)): + return expr + if expr.is_Integer: + return int(expr) + if expr.is_Float: + return float(expr) + + args = [] + for arg in expr.args: + if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None: + return None + args.append(arg_value) + args = tuple(args) + + func: OpOverload | None = _sympy_handlers().get(expr.func) # type: ignore[assignment] + if not func: + # Handler not found + return None + + if out is None: + out = func(*args) + else: + _sym_register(tracer, func, args, out) + return out + + +def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]: + # val.detach() will also eventually call fast_detach(), + # but this saves us a full trip into __torch_dispatch__ + # (snapshot_fake is called a lot) + if isinstance(val, FakeTensor): + return fast_detach(val.fake_mode, val, include_real) + else: + return val.detach() + + +_ExtractValType = Optional[ + Union[ + PySymType, + _AnyScriptObjectType, + BackwardState, + list["_ExtractValType"], + tuple["_ExtractValType", ...], + dict[str, "_ExtractValType"], + Tensor, + int, + float, + bool, + ] +] + + +def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType: + if is_fake(val): + return snapshot_fake(val, include_real=include_real) + elif isinstance(val, py_sym_types): + return val + elif isinstance(val, _AnyScriptObject): + return val + elif isinstance(val, BackwardState): + return val + elif isinstance(val, (list, tuple)): + return val.__class__([extract_val(x) for x in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, Tensor): + if not val.is_sparse: + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) + from torch._guards import detect_fake_mode + + fake_tensor_mode = detect_fake_mode(val) + if not fake_tensor_mode: + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + return torch.empty_strided( + val.shape, val.stride(), device=val.device, dtype=val.dtype + ) + else: + return None + elif isinstance(val, (int, float, bool)): + return val + elif val is None: + return None + + typing_extensions.assert_never(val) + + +@contextmanager +def _enable_thunkify( + tracer: _ProxyTracer, *, enable: bool = True +) -> Generator[None, None, None]: + """ + Enable thunkification inside the context manager. Thunkification prevents + SymNode computation from directly being traced into an FX graph; instead, + the compute is only added to the graph if it is actually used. This helps + us track SymNode compute when it is computed (since we need /something/ + to put in the tracker) even if it is unlikely to be used. + """ + old = tracer.enable_thunkify + tracer.enable_thunkify = enable + try: + yield + finally: + tracer.enable_thunkify = old + + +@contextmanager +def maybe_disable_thunkify() -> Generator[None, None, None]: + """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` + for more details. This is helpful if you have a wrapper function which + you want to enable thunkification on, but in some segment on the inside (say, + the original user function), you want to disable thunkification as you know + it is not needed there. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer, enable=False): + yield + else: + yield + + +@contextmanager +def maybe_enable_thunkify() -> Generator[None, None, None]: + """Within this context manager, if you are doing make_fx tracing, we will thunkify + all SymNode compute and avoid tracing it into the graph unless it is actually needed. + You should prefer to avoid using this as much as possible, as lazy evaluation of + SymNode tracing can lead to long chains of thunks which will stack overflow + if you evaluate them. However, this is currently sometimes necessary as there + are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error + due to insufficient tracing of SymNode computation. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer): + yield + else: + yield + + +# Note [invariants for node meta 'val'] +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) +def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: + proxy.node.meta["val"] = extract_val( + val, include_real=(proxy.node.op == "placeholder") + ) + + with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] + # Best effort tensor_meta setting; prefer using val! + if is_fake(val): + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + elif isinstance(val, Tensor) and not val.is_sparse: + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + return proxy + + +def thunkify( + tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs +) -> Thunk[R]: + """ + Delays computation of f until it's called again + Also caches the result + """ + if tracer.enable_thunkify: + return Thunk(functools.partial(f, *args, **kwargs)) + else: + r = f(*args, **kwargs) + return Thunk(lambda: r) + + +def track_tensor( + tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer +) -> None: + def try_set_proxy_slot( + outer_s: IntLikeType, + proxy_callable: Callable[Concatenate[PySymType, _P], Proxy], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + assert callable(proxy_callable) + if isinstance(outer_s, SymInt): + with _enable_thunkify(tracer): + set_proxy_slot( + outer_s, + tracer, + thunkify(tracer, proxy_callable, outer_s, *args, **kwargs), + ) + + # The basic idea is that we need to associate each tensor/SymInt + # with a Proxy. How do we setup this association? We just store + # the proxy on the proxy slot of the object, keyed on the tracer + # (so that if we have multiple tracers at the same time, they + # don't clobber each other.) + for i, s in enumerate(tensor.shape): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_size.int, (proxy, i), {} + ), + x, + ), + i, + ) + + if not is_sparse_any(tensor): + for i, s in enumerate(tensor.stride()): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {} + ), + x, + ), + i, + ) + + try_set_proxy_slot( + tensor.numel(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_numel.default, (proxy,), {} + ), + x, + ), + ) + if not is_sparse_any(tensor): + try_set_proxy_slot( + tensor.storage_offset(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", + torch.ops.aten.sym_storage_offset.default, + (proxy,), + {}, + ), + x, + ), + ) + set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) + + +_NestedProxys = Union[ + Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"] +] +_NestedTensors = Union[ + Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"] +] + + +def track_tensor_tree( + inner_res: T, + proxy_res: _NestedProxys, + *, + constant: Optional[_NestedTensors], + tracer: _ProxyTracer, +) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. + _set_unbacked_bindings(inner_res, proxy_res) + + def wrap_with_proxy( + e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] + ) -> None: + if isinstance(e, Tensor): + assert isinstance(proxy, Proxy) + assert constant is None or isinstance(constant, Tensor) + track_tensor(e, proxy, tracer=tracer, constant=constant) + set_meta(proxy, e) + elif isinstance(e, py_sym_types): + assert isinstance(proxy, Proxy) + # NB: eagerly set meta here, so that the numbering is in order + set_meta(proxy, e) + set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) + elif isinstance(e, _AnyScriptObject): + assert isinstance(proxy, Proxy) + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) + elif isinstance(e, (tuple, list)): + # example use case: allreduce_ returns ([tensor], work) + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + def get_constant( + c: Optional[_NestedTensors], idx: int + ) -> Optional[_NestedTensors]: + if c is None: + return None + else: + assert isinstance(c, (list, tuple)) + return c[idx] + + for idx, ee in enumerate(e): + # Use an indexer here - if proxy is a List then it will unwrap + # it. If it's a Proxy then it will proxy the getelem. + wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx)) # type: ignore[index] + + elif isinstance(e, dict): + # example use case: triton_kernel_wrapper takes arguments as kwargs + + # In theory we could support const-prop when proxy-tensor-tracing + # operators that returns dicts of tensors, but we have no use case + # for it today (since the only op we currently trace that can + # return a dict is triton_kernel_wrapper_functional/mutation, + # which does not participate in const-prop) + assert constant is None + + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + for key, val in e.items(): + wrap_with_proxy(val, proxy[key], None) # type: ignore[index] + + elif isinstance(e, BackwardState): + assert isinstance(proxy, Proxy) + set_meta(proxy, e) + e.proxy = proxy + else: + # intentionally pass on primitives + pass + + wrap_with_proxy(inner_res, proxy_res, constant) + + return inner_res + + +@dataclass +class _ProxyTensor: + proxy: Proxy + constant: Optional[Tensor] + + +def fetch_sym_proxy( + tracer: _ProxyTracer, +) -> Callable[[PySymType], Union[bool, int, float, Proxy]]: + def inner(e: PySymType) -> Union[int, bool, float, Proxy]: + n = e.node + if n.constant is not None: + return n.constant + if e.node.expr.is_number: + if isinstance(e, SymBool): + return bool(e.node.expr) + elif isinstance(e, SymInt): + return int(e.node.expr) + return float(e.node.expr) + else: + assert isinstance(e, py_sym_types) + # NB: we REQUIRE all symints to be tracked + return get_proxy_slot(e, tracer).force() + + return inner + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: Tensor +) -> Union[_ProxyTensor, Tensor]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: _AnyScriptObjectType +) -> Union[Proxy, _AnyScriptObjectType]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: PySymType +) -> Union[_PySymProxyType, PySymType]: ... + + +def fetch_object_proxy( + tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] +) -> object: + return get_proxy_slot(t, tracer, t) + + +HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor) + + +def _maybe_record_pointwise_barrier( + func: object, proxy_mode: ProxyTorchDispatchMode +) -> None: + """ + Records operators whose tensor outputs or inputs are fp16/bf16 so downstream pointwise code can + emulate eager's rounding behavior when emulate_precision_casts is enabled. + """ + if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts: + return + + if not isinstance(func, torch._ops.OpOverload): + return + + last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes))) + t = last_node.meta.get("val") + low_pr_fp = (torch.bfloat16, torch.float16) + + output_low_precision = isinstance(t, torch.Tensor) and t.dtype in low_pr_fp + + if not output_low_precision: + for input_node in last_node.all_input_nodes: + val = input_node.meta.get("val") if hasattr(input_node, "meta") else None + if isinstance(val, torch.Tensor) and val.dtype in low_pr_fp: + output_low_precision = True + break + + if not output_low_precision: + return + + last_node.meta["low_precision_pointwise_barrier"] = True + + +def _fetch_proxies_and_all_constant_flag( + flat_args_kwargs: Union[list[object], tuple[object, ...]], tracer: _ProxyTracer +) -> tuple[list[object], tuple[object, ...], bool]: + """ + Given flat arguments, fetch the proxies and whether they are all constants. + This is later used in proxy_call or when someone is trying to stitch together + graph node in tf or td modes. + """ + f_flat_args_kwargs = [ + ( + fetch_object_proxy(tracer, x) + if isinstance(x, (Tensor, _AnyScriptObject)) + else x + ) + for x in flat_args_kwargs + ] + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + not any( + t.constant is None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) + ) + + proxy_flat_args_kwargs = [ + e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs + ] + + proxy_flat_args_kwargs = [ + (fetch_sym_proxy(tracer)(e) if isinstance(e, py_sym_types) else e) + for e in proxy_flat_args_kwargs + ] + + return f_flat_args_kwargs, tuple(proxy_flat_args_kwargs), all_constant + + +def proxy_call( + proxy_mode: ProxyTorchDispatchMode, + func: OpOverload, + pre_dispatch: bool, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + unrecognized_types: list[type] = [] + flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) + + def can_handle_tensor(x: Tensor) -> bool: + r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) + if proxy_mode._allow_fake_constant: + r = r or type(x) is torch._subclasses.FakeTensor + if not r: + unrecognized_types.append(type(x)) + return r + + # If there are any tensor subclasses, we need to handle those tensor subclasses first + # TODO: we could use types to test this + if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)): + not_implemented_log.debug( + "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", + unrecognized_types, + ) + return NotImplemented + + r = maybe_handle_decomp(proxy_mode, func, args, kwargs) + if r is not NotImplemented: + _maybe_record_pointwise_barrier(func, proxy_mode) + return r + + # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. + if ( + not pre_dispatch + and func + not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ] + and autograd_would_have_decomposed(func, flat_args_kwargs) + ): + with proxy_mode: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + if func is torch.ops.aten.is_nonzero.default: + with proxy_mode: + torch._check( + args[0].numel() == 1, # type: ignore[attr-defined] + lambda: "Boolean value of Tensor with more than one value is ambiguous", + ) + return (args[0] != 0).item() # type: ignore[attr-defined] + + tracer = proxy_mode.tracer + f_flat_args_kwargs, proxy_flat_args_kwargs, all_constant = ( + _fetch_proxies_and_all_constant_flag(flat_args_kwargs, tracer) + ) + + if torch.Tag.data_dependent_output in func.tags: + # Check if all of the Tensor inputs are constants + if all_constant: + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + with unset_fake_temporarily(): + return func(*const_args, **const_kwargs) + # If any of the Tensor inputs are "real" (not FakeTensor), we may + # incorrectly burn in constants by allowing this access. Raise + # an error in this case + if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only( + Tensor, lambda t: not is_fake(t), (args, kwargs) + ): + raise RuntimeError( + f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " + "It's likely that this is caused by data-dependent control flow or similar. " + "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " + "in your make_fx call." + ) + + proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) + + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + if func is torch.ops.aten.lift_fresh.default: + func = torch.ops.aten.lift_fresh_copy.default + + proxy_out = proxy_mode.tracer.create_proxy( + "call_function", + func, + proxy_args, + proxy_kwargs, + name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), + ) + + with _enable_thunkify(proxy_mode.tracer): + out = func(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. + any_constant = any( + t.constant is not None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + + constant = None + + def tensor_numel_in_limit(t: Tensor) -> bool: + return t.numel() <= CONSTANT_NUMEL_LIMIT + + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + if ( + func is torch.ops.aten.lift_fresh_copy.default + and out.numel() <= CONSTANT_NUMEL_LIMIT + ): + with unset_fake_temporarily(): + assert isinstance(args[0], (Proxy, Tensor)), type(args[0]) + constant = args[0].clone() + elif ( + torch.Tag.nondeterministic_seeded not in func.tags + and all_constant + and any_constant + and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out) + ): + # NB: do NOT include factories as constants + with unset_fake_temporarily(): + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + constant = func(*const_args, **const_kwargs) + else: + constant = None + + track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + _maybe_record_pointwise_barrier(func, proxy_mode) + return out + + +class _SymNodeDict: + """ + Wrapper around a dictionary that will hash SymInts with their nodes + """ + + def __init__(self) -> None: + self.sym_node_dict: dict[PySymType, _PySymProxyType] = {} + + def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: + self.sym_node_dict[key.node] = value + + def __getitem__(self, key: PySymType) -> _PySymProxyType: + return self.sym_node_dict[key.node] + + def __contains__(self, key: PySymType) -> bool: + return key.node in self.sym_node_dict + + def get( + self, key: PySymType, default: Optional[_PySymProxyType] = None + ) -> _PySymProxyType: + # dict.get()'s annotation doesn't accept `None` when the value type + # isn't Optional. + return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type, return-value] + + def __iter__(self) -> Any: + raise NotImplementedError + + def __len__(self) -> int: + return len(self.sym_node_dict) + + +@dataclass +class _SympyExprTrackerValue: + proxy: _PySymProxyType + value: PySymType + + +class PythonKeyTracer(Tracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: _SymNodeDict + sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self) -> None: + super().__init__(autowrap_modules=()) # type: ignore[arg-type] + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + self.sympy_expr_tracker = {} + + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + self.enable_thunkify = False + + # In general, we don't want to make modules leaves. In principle, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the traced graph. + def call_module( + self, + m: Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + return forward(*args, **kwargs) + + # We don't want to turn getattr calls into proxies. So we just return the actual value. + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + return attr_val + + def create_arg(self, a: object) -> fx.node.Node: + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + + qualname = self.get_fresh_qualname("_param_constant") + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + elif isinstance(a, py_sym_types): + assert a.node.constant is not None + return a.node.constant + return super().create_arg(a) # type: ignore[return-value] + + @overload + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ... + + @overload + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ... + + @overload + def unwrap_proxy( + self, e: _AnyScriptObjectType + ) -> Union[Proxy, _AnyScriptObjectType]: ... + + def unwrap_proxy(self, e: T) -> object: + if isinstance(e, Tensor): + return get_proxy_slot(e, self, e, lambda x: x.proxy) # type: ignore[attr-defined] + elif isinstance(e, py_sym_types): + return get_proxy_slot(e, self, e, lambda e: e.force()) + elif isinstance(e, _AnyScriptObject): + return get_proxy_slot(e, self, e) + else: + return e + + def create_node( + self, + kind: str, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> torch.fx.Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type] + + if node.op in ["placeholder", "output"] and "stack_trace" in node.meta: + del node.meta["stack_trace"] + + if kind == "get_attr": + assert isinstance(target, str) + attr = getattr(self.root, target) + if isinstance(attr, torch.Tensor): + with disable_proxy_modes_tracing(): + node.meta["val"] = extract_val(attr) + + def map_fn(v: Any) -> Optional[_ExtractValType]: + if not isinstance(v, torch.fx.Node) or "val" not in v.meta: + return None + val = v.meta["val"] + # other subclasses like FunctionalTensor error on `extract_val` + # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now + if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor): + return None + return extract_val(v.meta["val"]) + + if _should_save_eager_input_vals(target, (args, kwargs)): + # NOTE "eager_input_vals" + # We save the original (args, kwargs) FakeTensor values for nodes + # that have exact stride requirements. This is useful downstream. + # We use this information inside Inductor to ensure that inputs to + # stride-sensitive operators have the correct strides. + arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] + node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) + + return node + + +def _should_save_eager_input_vals( + target: Any, + args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, +) -> bool: + from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP + + if not callable(target): + return False + if isinstance( + target, + ( + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + InvokeSubgraphHOP, + ), + ): + return True + if args_kwargs is not None and ( + target is torch.ops.higher_order.auto_functionalized + or target is torch.ops.higher_order.auto_functionalized_v2 + ): + args = args_kwargs[0] + assert isinstance( + args[0], (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) + return _should_save_eager_input_vals(args[0], None) + if target is torch.ops.higher_order.with_effects: + # TODO: inductor lowering for with_effects needs to be updated to propagate + # the arg_kwarg_vals + return False + if isinstance(target, torch._ops.HigherOrderOperator): + if pytree.tree_any(_should_save_eager_input_vals, args_kwargs): + raise RuntimeError( + f"NYI: The HOP {target} has an input that is an OpOverload that " + f"needs exact strides. We probably need special logic to " + f"propagate the FakeTensor vals. Please file an issue." + ) + if isinstance(target, torch._ops.OpOverload): + from torch._library.utils import get_layout_constraint_tag + + return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides + return False + + +def _make_temp_remove_mode_context_manager( + mode_ty: type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode + + temp_elements = [] + removed_mode = None + + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + try: + yield removed_mode + + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + return context_manager_fn + + +@torch._disable_dynamo +def dispatch_trace( + root: Union[Module, Callable], + tracer: Tracer, + concrete_args: Optional[tuple[Any, ...]] = None, +) -> GraphModule: + graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] + + # NB: be careful not to DCE .item() calls + def impure_pred(n: fx.Node) -> bool: + from .symbolic_shapes import is_accessor_node + + # Always defer to the built-in notion of impure + if n.is_impure(): + return True + + # Accessors always OK to DCE + if is_accessor_node(n): + return False + + # If the operator in question takes SymInt args to SymInt output, + # we assume it's pure and OK to DCE + if ( + isinstance(n.meta.get("val"), py_sym_types) + and + # NB: constant args ok + all( + isinstance(a.meta.get("val"), py_sym_types) + for a in n.args + if isinstance(a, fx.Node) + ) + ): + return False + + # No idea, just assume it's not OK + return True + + graph.eliminate_dead_code(impure_pred) + from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints + + dedupe_symints(graph) + name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ + return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) + + +def wrap_key( + f: Callable[[Unpack[_Ts]], R], + tensors: tuple[Unpack[_Ts]], + tracer: _ProxyTracer, + pre_dispatch: bool, +) -> Callable[_P, R]: + flat_tensors, _tensors_spec = pytree.tree_flatten(tensors) + + @functools.wraps(f) + def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: + nonlocal tensors + + flat_proxies, _proxies_spec = pytree.tree_flatten(proxies) + assert len(flat_proxies) == len(flat_tensors) + with disable_proxy_modes_tracing() as m: + assert isinstance(m, ProxyTorchDispatchMode) + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + + if getattr(tracer, "proxy_module_inputs", False): + tensors = [ # type: ignore[assignment, var-annotated] + p if isinstance(t, torch.nn.Module) else t + for t, p in zip(tensors, proxies) # type: ignore[arg-type] + ] + + def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: + return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined] + + out = f(*tensors) # type:ignore[call-arg] + out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out) + out = pytree.tree_map_only( + _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out + ) + + def get_sym_proxy_slot(t: PySymType) -> Proxy: + return get_proxy_slot(t, tracer).force() + + out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out) + return out + + return wrapped + + +# TODO: Make downstream users of this work with OperatorBase +ORIGINAL_ATEN: Optional[object] = None + + +@contextmanager +def set_original_aten_op( + func: OpOverload | torch._ops.HigherOrderOperator, +) -> Generator[None, None, None]: + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta["original_aten"] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta["original_aten"] = None + else: + yield + + +class TorchFunctionMetadataMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + + def __torch_function__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + # pyrefly: ignore [bad-assignment] + self.tracer.torch_fn_metadata = func + self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 + return func(*args, **kwargs) + + +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + +# This mode is **only** used for pre_dispatch tracing. +# In particular, we need to make sure that autograd/autocast API's +# that do not desugar into dispatcher operators stay in the graph. +class PreDispatchTorchFunctionMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + # The input to torch.amp.autocast_mode._exit_autocast graph node should be the + # enter_autocast node. So we have to save the enter autocast node here, and assign it + # to the exit_autocast call_function node. + self.enter_autocast_nodes: list[torch.fx.Node] = [] + + def __torch_function__( + self, + func: Union[OpOverload, Callable], + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + if func in _side_effectful_need_to_be_preserved_pre_dispatch: + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with export verifier, + # instead of hardcoding it here. + # T203648563 + if func is torch.amp.autocast_mode._exit_autocast: + enter_node = self.enter_autocast_nodes.pop() + args = (enter_node,) + node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] + if func is torch.amp.autocast_mode._enter_autocast: + self.enter_autocast_nodes.append(node) + if func in [ + torch._C._set_grad_enabled, + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ]: + node.meta["val"] = None + # For autocast, the python APIs run so we don't have to run them again + # here. + if func is torch._C._set_grad_enabled: + # pyrefly: ignore [bad-argument-type] + func(*args, **kwargs) + return node + + # We need more complicated handling here because the inputs + # to these functions are sometimes tensors or symints where + # we need to fetch the proxies properly. + if func in [ + torch._functorch.predispatch._add_batch_dim, + torch._functorch.predispatch._remove_batch_dim, + torch._functorch.predispatch._vmap_increment_nesting, + torch._functorch.predispatch._vmap_decrement_nesting, + torch._functorch.vmap.lazy_load_decompositions, + ]: + _, proxies, _ = _fetch_proxies_and_all_constant_flag(args, self.tracer) + out_proxy = self.tracer.create_proxy( + "call_function", + func, + proxies, + {}, + ) + res = func(*args, **kwargs) + track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer) + return res + return func(*args, **kwargs) + + +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + +class ProxyTorchDispatchMode(TorchDispatchMode): + # Ensure this is read-only; this exists only for legacy reasons + @property + def enable_tracing(self) -> bool: + return True + + def __init__( + self, + tracer: _ProxyTracer, + tracing_mode: str, + pre_dispatch: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + ) -> None: + dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None + super().__init__(dk) + self.tracer = tracer + self.tracing_mode = tracing_mode + self.pre_dispatch = pre_dispatch + self._allow_fake_constant = _allow_fake_constant + self._error_on_data_dependent_ops = _error_on_data_dependent_ops + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = [] + self.decomp_layers: int = 0 + from torch._inductor import config + + self.emulate_precision_casts: bool = config.emulate_precision_casts + + @count + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + with set_original_aten_op(func): + kwargs = kwargs or {} + + if func == prim.device.default: + return func(*args, **kwargs) + + return proxy_call(self, func, self.pre_dispatch, args, kwargs) + + def __enter__(self) -> Self: + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + self.enter_stack.append(maybe_prev_proxy_mode) + return super().__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: + b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + _push_mode(mb_previous_proxy_mode) + + return b + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + def __sym_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! + if func is operator.mul: + if isinstance(args[1], int) and args[1] == 1: + return args[0] + elif isinstance(args[0], int) and args[0] == 1: + return args[1] + + # For speed, we assume there are no nested data structures + # (otherwise we could use tree_map) + # We also assume there are no keyword arguments. + assert not kwargs + out = func(*args, **kwargs) + _sym_register(self.tracer, func, args, out) + return out + + +def _sym_register( + tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: object +) -> None: + # If func returned a constant, we don't need to trace; we have + # determined that the result is constant (no matter if the inputs + # were symbolic) and it is no longer necessary to trace the + # computation. This could occur if func triggered some guards. + if isinstance(out, py_sym_types): + p_out_thunk = thunkify( + tracer, _compute_proxy, tracer, func=func, args=args, out=out + ) + set_proxy_slot(out, tracer, p_out_thunk) + + +def _compute_proxy( + tracer: _ProxyTracer, func: OpOverload, args: tuple[object, ...], out: PySymType +) -> Proxy: + # Handle torch.sym_sum + n_args: tuple[object, ...] + if len(args) == 1 and isinstance(args[0], (list, tuple)): + n_args = ( + tuple( + ( + get_proxy_slot(a, tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args[0] + ), + ) + else: + n_args = tuple( + ( + get_proxy_slot(a, tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args + ) + + # func doesn't have a __torch_function__ that Proxy can interpose, so + # we gotta do it manually + n_out = tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type] + p_out = fx.Proxy(n_out, tracer) + set_meta(p_out, out) + return p_out + + +class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: MutableMapping[PySymType, _PySymProxyType] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue] + torch_fn_metadata: Optional[OpOverload] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self, graph: fx.graph.Graph) -> None: + super().__init__(graph) + self.symnode_tracker = weakref.WeakKeyDictionary() + self.tensor_tracker = WeakTensorKeyDictionary() + self.sympy_expr_tracker = {} + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + + +# TODO: I'm not sure what the point of this class is; you can just +# make_fx through a regular Interpreter +class DecompositionInterpreter(fx.Interpreter): + def __init__( + self, + module: fx.GraphModule, + new_graph: fx.Graph, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + **kwargs: object, + ) -> None: + super().__init__(module, **kwargs) # type: ignore[arg-type] + self.new_graph = new_graph + self.tracer = _GraphAppendingTracerEx(self.new_graph) + # Blegh + self.decomposition_table = decomposition_table or {} + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + # pyrefly: ignore [bad-override] + def placeholder( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + # TODO handle case where the first character of target is '*' + return out + + # pyrefly: ignore [bad-override] + def get_attr( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + # call_function, call_method, call_module get traced automatically by the outer mode. + + # pyrefly: ignore [bad-override] + def output( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().output(target, args, kwargs) # type: ignore[arg-type] + + def get_proxy_node(x: _ProxyTensor) -> fx.node.Node: + return x.proxy.node + + def unwrap(e: Tensor) -> Union[Tensor, fx.Node]: + return get_proxy_slot(e, self.tracer, e, get_proxy_node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args: object, **kwargs: object) -> object: + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) # type: ignore[arg-type] + + +class _SelectiveDecomposeInterpreter(fx.Interpreter): + def __init__( + self, + module: fx.GraphModule, + should_decompose: Callable[[fx.Node], bool], + decomposition_table: Mapping[OpOverload, Callable], + **kwargs: object, + ) -> None: + """ + For all nodes in `module`, selectively decompose if is `should_decompose`, + following the given `decomposition_table`. + """ + super().__init__(module, **kwargs) # type: ignore[arg-type] + self.should_decompose = should_decompose + self.decomposition_table = decomposition_table + + @staticmethod + def recursive_wrap( + gm: fx.GraphModule, + should_decompose: Callable[[fx.Node], bool], + decomposition_table: Mapping[OpOverload, Callable], + **kwargs: object, + ) -> _SelectiveDecomposeInterpreter: + """ + Recursively wrap gm and its sub graph modules. Specifically, HOP takes + sub graph module as args. We may not want to decompose all nodes within + these sub graph modules. So we also need to wrap these sub graph modules. + As a result: + - if should_decompose(hop) is True, we decompose all nodes within the hop. + - if should_decompose(hop) is False, we check each node within the hop + and decide whether decompose or not. + """ + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, HigherOrderOperator + ): + new_args = [] + for arg in node.args: + if isinstance(arg, fx.GraphModule): + new_arg = _SelectiveDecomposeInterpreter.recursive_wrap( + arg, should_decompose, decomposition_table, **kwargs + ) + else: + new_arg = arg + new_args.append(new_arg) + node.args = tuple(new_args) + + return _SelectiveDecomposeInterpreter( + gm, should_decompose, decomposition_table, **kwargs + ) + + def run_node(self, n): + if self.should_decompose(n): + with decompose(self.decomposition_table): + result = super().run_node(n) + else: + result = super().run_node(n) + return result + + +def selective_decompose( + joint_gm: fx.GraphModule, + *args, + decomposition, + should_decompose, + trace_joint_graph: bool, +) -> fx.GraphModule: + """Retrace a joint graph module and selectively apply decomposition.""" + + if trace_joint_graph: + # the arg name, primals and tangents, are important. + # make_fx keeps the name in the traced graph and partitioner later relies + # on the name to partition joint graph correctly. + def wrap_fn(primals: list[Any], tangents: list[Any]): + return _SelectiveDecomposeInterpreter.recursive_wrap( + joint_gm, should_decompose, decomposition + ).run(*args) + else: + + def wrap_fn(*args): + return _SelectiveDecomposeInterpreter.recursive_wrap( + joint_gm, should_decompose, decomposition + ).run(*args) + + return make_fx(wrap_fn, decomposition_table={})(*args) + + +def wrapper_and_args_for_make_fx( + func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object] +) -> tuple[Callable[[list[object]], R], list[object]]: + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(flat_args: list[object]) -> R: + fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) + return func(*fn_args, **fn_kwargs) + + return wrapped, flat_args + + +@contextmanager +def disable_autocast_cache() -> Generator[None, None, None]: + old_value = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + try: + yield + finally: + torch.set_autocast_cache_enabled(old_value) + + +class _ModuleNotInstalledAsSubmoduleError(NameError): + pass + + +# Base class for inline _ModuleStackTracer.__init__.AttrProxy +class _AttrProxy: + def reset_proxy_mapping(self, base: Module, path: str) -> None: + pass + + +class _ModuleStackTracer(PythonKeyTracer): + r"""Customized version of PythonKeyTracer that retains module stack + information in node.meta["nn_module_stack"]. + + FX symbolic trace actually does this already, but it relies on `self.root` + being the actual module being traced. Since make_fx traces a lambda of our + creation, things don't work properly. + + So for this version we hold onto a reference to the original module + (scope_root) and use that to match the path. Also when we see, + A + / \ + B C + \ / + D + we want to record the path as A.B.D by recording only one path. + See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 + """ + + def __init__(self, scope_root: GraphModule) -> None: + super().__init__() + self.record_stack_traces = True + self._record_forward_stack_traces_only = True + self.scope_root = scope_root + self.enable_attr_proxy = False + self.submodule_paths = {} + for name, m in self.scope_root.named_modules(remove_duplicate=False): + if m in self.submodule_paths: + log.info( + "Shared module found between %s and %s, AttrProxy is enabled.", + self.submodule_paths[m], + name, + ) + self.enable_attr_proxy = True + else: + self.submodule_paths[m] = name + + self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() + self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary() + self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary() + self.counter = 0 + + self.module_id_cache = defaultdict(list) + for name, mod in self.scope_root.named_modules(remove_duplicate=False): + self.module_id_cache[id(mod)].append(name) + + # Build a wrapper around _AttrProxy to provide the tracer. We can't + # store it on _AttrProxy itself beceause we mimic the underlying class + # (including its attributes). + tracer = self + + class AttrProxy(_AttrProxy): + def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None: + if isinstance(base, _AttrProxy): + base = base.get_base() # type: ignore[attr-defined] + + assert isinstance(base, Module) + # Class is modified to be a subclass of torch.nn.Module + # Warning: We blow away our own attributes here to mimic the base class + # - so don't expect `self.x` to do anything useful. + # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-override] + self.__class__ = type( + base.__class__.__name__, + (self.__class__, base.__class__), + {}, + ) + self.__dict__ = base.__dict__ + self.__class__.__module__ = base.__class__.__module__ + self.__class__.__qualname__ = base.__class__.__qualname__ + + # This overwrites any existing paths if `base` is an AttrProxy + tracer.proxy_paths[self] = path + tracer.proxy_modules[self] = base + + def __getattr__(self, name: str) -> AttrProxy: + assert isinstance(self, Module) + # Calling into torch.nn.Module.__getattr__ with super(), + # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py. + # which then calls into _ModuleStackTracer.getattr + attr_val = super().__getattr__(name) # type: ignore[misc] + if not isinstance(attr_val, Module): + return attr_val + + # pyrefly: ignore [index-error] + return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name) + + def get_base(self) -> Module: + return tracer.proxy_modules[self] + + def __getitem__(self, idx: Union[int, slice]) -> AttrProxy: + if isinstance(idx, slice): + if isinstance(self, torch.nn.Sequential): + # Copied from nn/modules/container.py + res = torch.nn.Sequential( + OrderedDict(list(self._modules.items())[idx]) + ) + # pyrefly: ignore [index-error] + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + elif isinstance(self, torch.nn.ModuleList): + # Copied from nn/modules/container.py + res = torch.nn.ModuleList(list(self._modules.values())[idx]) + # pyrefly: ignore [index-error] + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + + return super().__getitem__(idx) # type: ignore[misc] + + @property + def _modules(self) -> dict[str, AttrProxy]: + assert "_modules" in self.__dict__ + submodules = self.__dict__["_modules"] + assert isinstance(submodules, dict) + return { + key: ( + AttrProxy(value, tracer.proxy_paths[self] + "." + str(key)) # type: ignore[misc] + if value is not None + else value + ) + for key, value in submodules.items() + } + + self.proxy_type = AttrProxy + + def path_of_module(self, mod: Module) -> str: + """ + Use tracked access path during tracing instead of the default BFS behavior. + Still use all the possible module paths to verify the result. + """ + if mod is self.scope_root: + return "" + + if isinstance(mod, _AttrProxy): + return self.proxy_paths[mod] + + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise _ModuleNotInstalledAsSubmoduleError from e + + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + if ( + not isinstance(attr_val, Module) + or isinstance(attr_val, fx.GraphModule) + or not self.enable_attr_proxy + ): + return super().getattr(attr, attr_val, parameter_proxy_cache) + if isinstance(attr_val, _AttrProxy): + return attr_val + + # See NOTE [caching AttrProxy]. + if attr_val not in self.attr_proxy_map: + self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr) + else: + self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr) + return self.attr_proxy_map[attr_val] + + def trace( # type: ignore[override] + self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]] + ) -> fx.Graph: + res = super().trace(root, concrete_args) + + # NOTE [export non-strict fake tensor leak detection] + # In non-strict export, we don't have dynamo's side effect + # tracking logic which makes some cases hard to detect. + # In general, our detecting strategy is: + # (1) We instrument fake tensor creation to log all the fake tensors created during export. + # (2) We dump the proxy to fake tensor map from make_fx tracer (_FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT)) + # (3) Filter out fake tensors that are logged during (1): + # (1) Associated with TrackedFake (input tracking thing in symbolic_shapes) + # (2) Associated with gm.meta + # (4) Do ID match with the proxies + + global _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT + _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT.clear() + + for key, val in self.tensor_tracker.items(): + _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[id(key)] = val.proxy.node + + # Since we are making _AttrProxy mimic the original + # submodule, when someone registers a module directly + # to the tracer while tracing, the proxy object gets registered + # first. So we need to replace the proxy modules with the real ones + # This can happen during HOO tracing + proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = [] + for name, module in self.root.named_modules(): + if module in self.proxy_modules: + proxy_module_names_to_be_replaced.append((name, module)) + + def _delete_proxy_attr(obj: Module, target: str) -> bool: + # Copied from fx/graph_module.py + # Customized it for proxy type + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + assert isinstance(obj, Module) + mod = obj + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, (_AttrProxy, Module)): + return False + + if not hasattr(mod, target_submod): + return False + + # At least the leaf module should be proxy type. + if not isinstance(getattr(mod, target_submod), _AttrProxy): + return False + + delattr(mod, target_submod) + return True + + for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced: + _delete_proxy_attr(self.root, proxy_module_name) + actual_module = self.proxy_modules[proxy_module] + _assign_attr(actual_module, self.root, proxy_module_name) + + return res + + def call_module( + self, + m: Module, + forward: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> None: + """PythonKeyTracer overrides call_module to avoid the scope handling, + but we actually want it. + """ + from torch._dynamo import OptimizedModule + + # FIXME (tmanlaibaatar) + # When we call torch.compile inside HOO, we will end up + # invoking a module that is not registered on the root. For + # now, we just inline them. But once we start supporting + # mark_strict in export, we do need to properly handle this. + # Right now, it doesn't matter because current non-strict + # use cases don't need to work with HOO. + if isinstance(m, (OptimizedModule, GraphModule)): + return forward(*args, **kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except _ModuleNotInstalledAsSubmoduleError: + log.debug( + "Unable to find the path of the module %s. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information.", + str(m), + ) + return forward(*args, **kwargs) + + def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool: + return False + + def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: + """ + Create node and add on metadata. + Add nn_module_stack here instead of TracerBase, + since calls to make_fx() might not want to record module stack metadata. + Add torch_fn by looking at torch_fn_metadata and torch_fn_counts. + Add stack_trace by filtering out forward() stack frames. + """ + node = super().create_node(*args, **kwargs) # type: ignore[arg-type] + + # nn_module_stack + if node.op not in ["placeholder", "output"]: + if node.meta.get("nn_module_stack") is None: + node.meta["nn_module_stack"] = self.module_stack.copy() + # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] + for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): + if isinstance(mod_cls, type): + node.meta["nn_module_stack"][key] = ( + fqn, + mod_cls.__module__ + "." + mod_cls.__qualname__, + ) + + # torch_fn + if ( + node.op == "call_function" + and self.torch_fn_metadata is not None + and "torch_fn" not in node.meta + ): + node.meta["torch_fn"] = ( + f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", + f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", + ) + + return node + + +class _MakefxTracer: + def __init__( + self, + decomposition_table: Optional[Mapping[OpOverload, Callable]], + tracing_mode: str, + _allow_non_fake_inputs: bool, + pre_dispatch: bool, + record_module_stack: bool, + _allow_fake_constant: bool, + _error_on_data_dependent_ops: bool, + record_stack_traces: bool = False, + parent_tracer: Optional[_MakefxTracer] = None, + proxy_module_inputs: bool = False, + ) -> None: + # Configurations that are used to initialize the context managers and their states. + # Should not modify them during tracing. + self.decomposition_table: dict[OpOverload, Callable] = dict( + decomposition_table or {} + ) + self.decomposition_table.setdefault( + torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel + ) + self.tracing_mode: str = tracing_mode + self._allow_non_fake_inputs: bool = _allow_non_fake_inputs + self.pre_dispatch: bool = pre_dispatch + self.record_module_stack: bool = record_module_stack + self._allow_fake_constant: bool = _allow_fake_constant + self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops + + # All context managers and their states should be initialized before tracing based on the inputs + # and configurations. After tracing, their states should be cleaned except for shape_env. + # Remember to specify how to initialize it from user inputs and from parent tracer whenever + # adding new modes in _MakefxTracer. + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() + self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = ( + nullcontext() + ) + self.fx_tracer: Optional[PythonKeyTracer] = None + self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() + self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( + nullcontext() + ) + self.record_stack_traces = record_stack_traces + self.parent_tracer: Optional[_MakefxTracer] = parent_tracer + self.proxy_module_inputs = proxy_module_inputs + + def _checkpoint_modes(self) -> list[Any]: + return [ + self.fake_tensor_mode, + self.proxy_mode, + self.proxy_function_mode, + self.fx_tracer, + self.python_dispatcher_mode, + self.torch_fn_metadata_mode, + ] + + def _restore_modes( + self, + prev_fake_tensor_mode: Optional[FakeTensorMode], + prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode], + prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode], + prev_fx_tracer: Optional[PythonKeyTracer], + prev_python_dispatcher_mode: Union[nullcontext, Any], + prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode], + ) -> None: + self.fake_tensor_mode = prev_fake_tensor_mode + self.proxy_mode = prev_proxy_mode + self.proxy_function_mode = prev_proxy_function_mode + self.fx_tracer = prev_fx_tracer + self.python_dispatcher_mode = prev_python_dispatcher_mode + self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode + + @contextmanager + def _init_modes_from_inputs( + self, f: Callable, args: tuple[object, ...] + ) -> Generator[None, None, None]: + prev_modes = self._checkpoint_modes() + try: + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + + if hasattr(f, "_orig_mod") and self.record_module_stack: + scope_root = f._orig_mod + # _ModuleStackTracer always try to preserve stack trace + # in forward functions + self.fx_tracer = _ModuleStackTracer(scope_root) + else: + self.fx_tracer = PythonKeyTracer() + self.fx_tracer.record_stack_traces = self.record_stack_traces + if self.record_stack_traces: + self.fx_tracer._record_forward_stack_traces_only = True + + if self.tracing_mode == "fake": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=ShapeEnv(), + static_shapes=True, + ) + self.fake_tensor_mode = fake_tensor_mode + elif self.tracing_mode == "symbolic": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + shape_env = ShapeEnv() + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=shape_env, + ) + assert fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + self.fake_tensor_mode = fake_tensor_mode + else: + if not self.tracing_mode == "real": + raise AssertionError( + f"Unexpected tracing type: {self.tracing_mode}" + ) + + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None: + self.proxy_mode = ProxyTorchDispatchMode( + fx_tracer, + self.tracing_mode, + pre_dispatch=self.pre_dispatch, + _allow_fake_constant=self._allow_fake_constant, + _error_on_data_dependent_ops=self._error_on_data_dependent_ops, + ) + + if self.pre_dispatch: + self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer) + + # pre-autograd tracing uses per-dispatch-key modes, + # which requires the python dispatcher + if self.tracing_mode == "symbolic" or self.pre_dispatch: + self.python_dispatcher_mode = enable_python_dispatcher() + + self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer) + fx_tracer.proxy_module_inputs = self.proxy_module_inputs # type: ignore[union-attr] + + @contextmanager + def _init_modes_from_parent( + self, parent_tracer: _MakefxTracer + ) -> Generator[None, None, None]: + # By default, subtracer creates new modes based on parent tracer's config. + # However, there are cases where we want to share the same modes with parent tracer + # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same. + prev_modes = self._checkpoint_modes() + try: + self.fake_tensor_mode = parent_tracer.fake_tensor_mode + + def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer: + if type(parent_tracer) is PythonKeyTracer: + return PythonKeyTracer() + elif type(parent_tracer) is _ModuleStackTracer: + return _ModuleStackTracer(parent_tracer.scope_root) + else: + raise RuntimeError( + f"Unexpected tracer type: {type(parent_tracer)}." + ) + + assert parent_tracer.fx_tracer is not None + self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer) + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _trace_inner(self, f: Callable, *args: object) -> GraphModule: + # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace, + # because dispatch_trace will introduce the lazy import of torch._dynamo, + # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo, + # such as some torch API(torch.ones and so on) in populate_builtin_to_tensor_fn_map() will be affected + # by the context set before dispatch_trace. + import torch._dynamo + + phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args) + + def _wrap_fake(args: T) -> T: + arg_count = 0 + + def inner_wrap_fake(x: object) -> object: + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + + assert self.fake_tensor_mode is not None + source = ConstantSource(f"input{arg_count}") + if isinstance(x, Tensor): + arg_count += 1 + return self.fake_tensor_mode.from_tensor(x, source=source) + # NB: don't match on bools + elif type(x) is int and self.tracing_mode == "symbolic": + assert self.fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + return self.fake_tensor_mode.shape_env.create_symintnode( + self.fake_tensor_mode.shape_env.create_symbol( + x, source, positive=None + ), + hint=x, + source=source, + ) + elif isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)): + return torch._library.fake_class_registry.maybe_to_fake_obj( + self.fake_tensor_mode, x + ) + + assert not isinstance(x, FakeScriptObject), ( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) + return x + + wrap_fn_map = { + "real": lambda x: x, + "fake": inner_wrap_fake, + "symbolic": inner_wrap_fake, + } + return pytree.tree_map(wrap_fn_map[self.tracing_mode], args) + + def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: + if ( + not hasattr(inspect.unwrap(f), "__code__") + or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS + ): + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + return fake_signature(f, len(phs)) + return f + + args = _wrap_fake(args) + func = _wrap_func(f, phs) + # We disable the autocast cache as the autocast cache causes type conversions on parameters to + # check a cache, which introduces untracked tensors into the graph + # + # We also disable tracing by any other tensor proxy-based tracers except the current. The + # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is + # thus irrelevant to any external functional trace. + proxy_mode: ProxyTorchDispatchMode = typing.cast( + ProxyTorchDispatchMode, self.proxy_mode + ) + with ExitStack() as stack: + stack.enter_context(decompose(self.decomposition_table)) + if self.fake_tensor_mode: + stack.enter_context(self.fake_tensor_mode) + stack.enter_context(self.python_dispatcher_mode) + stack.enter_context(self.proxy_function_mode) + stack.enter_context(self.torch_fn_metadata_mode) + stack.enter_context(proxy_mode) + stack.enter_context(disable_autocast_cache()) + stack.enter_context(_set_make_fx_tracer(self)) + + assert self.fx_tracer is not None + try: + t = dispatch_trace( + wrap_key(func, args, self.fx_tracer, self.pre_dispatch), + tracer=self.fx_tracer, + concrete_args=tuple(phs), + ) + except Exception: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "make_fx_fail_partial", + "encoding": "string", + }, + payload_fn=lambda: self.fx_tracer.graph.python_code( # type: ignore[union-attr] + root_module="self", + verbose=True, + include_stride=True, + include_device=True, + ).src, + ) + raise + + if ( + self.is_hop_subgraph_tracer() + and (fake_mode := torch._guards.detect_fake_mode(args)) + and fake_mode.shape_env is not None + ): + from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx") + t.recompile() + # TODO: kind of a bad way to do it, should maybe figure out a better way + if self.tracing_mode == "symbolic": + assert self.fake_tensor_mode is not None + t.shape_env = self.fake_tensor_mode.shape_env # type: ignore[assignment] + return t + + def trace(self, f: Callable, *args: object) -> fx.GraphModule: + with self._init_modes_from_inputs(f, args): + return self._trace_inner(f, *args) + + def is_hop_subgraph_tracer(self) -> bool: + return self.parent_tracer is not None + + def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: + # Create a new tracer based on parent's config + sub_tracer = _MakefxTracer( + self.decomposition_table, + "real", + self._allow_non_fake_inputs, + self.pre_dispatch, + self.record_module_stack, + self._allow_fake_constant, + self._error_on_data_dependent_ops, + parent_tracer=self, + ) + with sub_tracer._init_modes_from_parent(self): + return sub_tracer._trace_inner(f, *args) + + +_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None + + +@contextmanager +def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]: + global _CURRENT_MAKE_FX_TRACER + prev_tracer = _CURRENT_MAKE_FX_TRACER + try: + _CURRENT_MAKE_FX_TRACER = tracer + yield + finally: + _CURRENT_MAKE_FX_TRACER = prev_tracer + + +def make_fx( + f: Callable, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + tracing_mode: str = "real", + _allow_non_fake_inputs: bool = False, + *, + pre_dispatch: bool = False, + record_module_stack: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + record_stack_traces: bool = False, + proxy_module_inputs: bool = False, +) -> Callable[..., GraphModule]: + """ + Given a function f, return a new function which when executed with valid + arguments to f, returns an FX GraphModule representing the set of operations that + were executed during the course of execution. + + If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"] + """ + + assert tracing_mode in ["real", "fake", "symbolic"] + + from torch._inductor import config + + make_fx_tracer = _MakefxTracer( + decomposition_table, + tracing_mode, + _allow_non_fake_inputs, + pre_dispatch, + record_module_stack, + _allow_fake_constant, + _error_on_data_dependent_ops, + record_stack_traces=record_stack_traces + or config.trace.provenance_tracking_level == 1, + proxy_module_inputs=proxy_module_inputs, + ) + + @functools.wraps(f) + def wrapped(*args: object) -> GraphModule: + return make_fx_tracer.trace(f, *args) + + return wrapped + + +def get_torch_dispatch_modes() -> list[TorchDispatchMode]: + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() + + +# TODO: this is a legacy name, there is only ever one proxy mode as it's an +# infra mode +def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + return get_proxy_mode() + + +def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + """ + Current the currently active proxy tracing mode, or None if + we are not currently tracing. This includes pre-dispatch proxy + tracing. + """ + pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.PROXY + ) + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + assert pre_dispatch_mode is None or mode is None, ( + f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + ) + return pre_dispatch_mode or mode + + +def handle_sym_dispatch( + func: Callable[_P, R], + args: _P.args, # type: ignore[valid-type] # not allowed to use _P.args here + kwargs: _P.kwargs, # type: ignore[valid-type] # not allowed to use _P.kwargs here +) -> R: + """ + Call into the currently active proxy tracing mode to do a + SymInt/SymFloat/SymBool dispatch trace on a function that operates on + these arguments. + """ + mode = get_proxy_mode() + assert mode + # Have to do it manually, because we're not doing the normal torch + # dispatch machinery which disables it for us + with disable_proxy_modes_tracing(): + # TODO: properly compute types + types: list[type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] + + +@contextmanager +def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]: + return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + + +def maybe_handle_decomp( + proxy_mode: ProxyTorchDispatchMode, + op: OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + from torch._inductor.compiler_bisector import CompilerBisector + + if op in CURRENT_DECOMPOSITION_TABLE: + if CompilerBisector.disable_subsystem( + "aot_eager_decomp_partition", "decomposition", lambda: repr(op) + ): + return NotImplemented + + with proxy_mode: + proxy_mode.decomp_layers += 1 + out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + proxy_mode.decomp_layers -= 1 + return out + + return NotImplemented + + +def get_isolated_graphmodule( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + tracing_mode: str = "real", + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, +) -> GraphModule: + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) + + with disable_proxy_modes_tracing(): + gm = make_fx( + wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode + )(all_args) + return gm + + +def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: + """A helper function for setting up unbacked_bindings on the destination FX graph.""" + from .symbolic_shapes import compute_unbacked_bindings + + # Can't use detect_fake_mode here, + # + # python test/distributed/_tensor/test_dtensor_compile.py -k + # test_tp_compile_fullgraph_is_seq_parallel_False + # + # will fail. Very strange, it probably isn't right for them to be using + # two fake modes there... + fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + if fake_mode and fake_mode.shape_env: + if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): + assert isinstance(out_proxy, Proxy), out_proxy + out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec092898cd69d74362acbe57a029b09d9b23bee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py @@ -0,0 +1,530 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import itertools +import logging +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree + + +log = logging.getLogger(__name__) +trace_shape_events_log = torch._logging.getArtifactLogger( + __name__, "trace_shape_events" +) + + +__all__ = [ + "ShapeEnvEvent", + "record_shapeenv_event", + "replay_shape_env_events", + "FakeTensorMeta", + "shape_env_check_state_equal", + "NotEqualError", +] + +# [Note: Recording ShapeEnv Events] +# ================================= +# +# What is a ShapeEnv event? +# ------------------------- +# We consider a ShapeEnv event every function call (ShapeEnv method or +# independent function) that modifies the state of the ShapeEnv instance. +# Such calls are recorded alongside their positional and keyword arguments, +# so that it may be replayed over a different ShapeEnv instance. +# +# See [Note: ShapeEnv State Equality] for what is considered the state +# of a ShapeEnv instance. +# +# What is it for? +# --------------- +# ShapeEnv events recording is used for reconstructing the ShapeEnv in an +# arbitrary state in time. +# +# Being able to arbitrarily replay events like so is useful, mainly for +# translation validation bisection. i.e. if a ValidationException has been +# raised, find the earliest point in time where the translation validation +# fails. +# +# Besides that, it also allows us to inspect the given instance and, +# for example, check the guards that would actually be issued at that point. +# +# What kind of arguments can be stored in an event? +# ------------------------------------------------- +# There's no specific rule for what cannot be used as an argument. +# That said, pay special attention to the following cases: +# +# 1. Tensor inputs: there are some tests that check whether the inputs +# were garbage collected after execution. These will fail if there's +# an event that is holding a reference to those inputs. +# +# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that +# will be automatically replaced by the new given ShapeEnv instance. +# +# 3. SymTypes arguments: they also hold references to ShapeEnv. So, +# whenever we see them, we create a new instance, replacing the +# ShapeEnv reference. +# +# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic +# shapes. That argument must be replaced when replaying the event at +# ShapeEnvEvent.run, since it has to reference a node from the given +# instance, and not from the recorded instance. + + +# Event class for reconstructing ShapeEnv at arbitrary time. +# +# Represents a method call that mutates ShapeEnv in a way that affects the +# issued guards, when ShapeEnv.produce_guards is called. +@dataclass +class ShapeEnvEvent: + # ShapeEnv method. + f: Callable + + # Arguments and keyword arguments called with. + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None + + # List of tracked_fakes at the time the method was called. + tracked_fakes: Optional[list[Any]] = None + + # Name of the captured event. + # Used for special handling of particular methods. + name: Optional[str] = None + + # Replay itself, but using shape_env as self. + def run(self, shape_env=None) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + is_symbolic, + ShapeEnv, + SymTypes, + ) + + # Special handling for the constructor event. + if self.f is ShapeEnv: + assert shape_env is None and self.args is None and self.kwargs is not None + return ShapeEnv(**self.kwargs) + + assert shape_env is not None + args = list(self.args or []) + kwargs = dict(self.kwargs or {}) + + # Replace any argument of type ShapeEnv by the given one. + args, kwargs = pytree.tree_map_only( + ShapeEnv, lambda _: shape_env, (args, kwargs) + ) + + # Replace any argument of type SymTypes by a new instance, + # replacing its ShapeEnv reference. + args, kwargs = pytree.tree_map_only( + lambda x: isinstance(x, SymTypes) and is_symbolic(x), + lambda a: type(a)(a.node.with_shape_env(shape_env)), + (args, kwargs), + ) + + # Converts FX nodes using the mapping argument. + def maybe_convert_node(x: Any) -> Any: + if not isinstance(x, torch.fx.Node): + # Don't do anything to x if it's not an FX node. + return x + + # If, at some point, we created an FX node, it means that translation validation is on. + # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and + # we are tracking node names at shape_env.name_to_node. + assert hasattr(shape_env, "name_to_node") + name_to_node = shape_env.name_to_node # type: ignore[attr-defined] + assert x.name in name_to_node + return name_to_node[x.name] + + # Replaces the value of an specific argument by the result of fn. + def replacearg(index: int, key: str, fn: Callable): + if index < len(args): + args[index] = fn(args[index]) + if key in kwargs: + kwargs[key] = fn(kwargs[key]) + + if self.is_create_fx_call_function(): + # ShapeEnv.create_fx_call_function: + # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. + # They must be replaced, since a "call_function" FX node with this tuple as argument + # will be added to the FX graph of the new shape_env. + replacearg( + index=2, + key="args", + fn=lambda args: tuple(maybe_convert_node(a) for a in args), + ) + if self.is_evaluate_expr() or self.is_defer_runtime_assert(): + # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert: + # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. + # They must be replaced, since it will be part of a "call_function" FX node for + # torch._assert, which will be added to the FX graph of the new shape_env. + replacearg(index=3, key="fx_node", fn=maybe_convert_node) + + # Actually call the method with the converted arguments. + return self.f(*args, **kwargs) + + def __str__(self) -> str: + name = self.name if self.name is not None else self.f.__name__ + return f"event: {name} ({self.args}, {self.kwargs})" + + def is_create_fx_call_function(self) -> bool: + return self.name == "_create_fx_call_function" + + def is_evaluate_expr(self) -> bool: + return self.name == "evaluate_expr" + + def is_defer_runtime_assert(self) -> bool: + return self.name == "guard_or_defer_runtime_assert" + + +NEST = 0 + + +# Extracts a ShapeEnv instance inside args and kwargs. +# Specifically, it looks for: +# 1. ShapeEnv arguments +# 2. SymInt, SymFloat, or SymBool arguments +# If we find more than one object of any of the above types, we +# also check that the ShapeEnv instance is the same for all of them. +def _extract_shape_env_and_assert_equal(args, kwargs): + from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes + + def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: + if old is not None: + assert old is new, "call with different ShapeEnv" + return new + + shape_env = None + for val in itertools.chain(args, kwargs.values()): + if isinstance(val, ShapeEnv): + shape_env = assert_equal(shape_env, val) + if isinstance(val, SymTypes) and is_symbolic(val): + shape_env = assert_equal(shape_env, val.node.shape_env) + + return shape_env + + +# Decorator for recording the given function as a replayable event. +# +# This decorator should be used at every function that mutates the state of +# ShapeEnv in some way that affects the resulting issued guards (i.e. when +# ShapeEnv.produce_guards is called). +# +# save_tracked_fakes: saves a snapshot of the TrackedFake list. +# This is used when calling ShapeEnv.produce_guards at arbitrary points in time. +# +# name: the name of the function being recorded. Normally (and by default) this +# is taken from the decorated function but can be set if you need to override +# it. +# +# When to save the list of TrackedFake? +# ===================================== +# We should save the list of TrackedFake whenever the translation validation +# bisection may actually stop and call the produce_guards method at the moment +# right after the recorded function was played. In other words, since the +# bisection bisects through torch._assert calls, we should save in all methods +# that adds a torch._assert call to the symbolic shapes FX graph. +# +# At the moment, there are 2 methods that save the list: +# - ShapeEnv.evaluate_expr +# - ShapeEnv.guard_or_defer_runtime_assert +def record_shapeenv_event( + *, save_tracked_fakes: bool = False, name: Optional[str] = None +) -> Callable: + def decorator(fn: Callable) -> Callable: + assert callable(fn) + args = inspect.getfullargspec(fn).args + assert args and args[0] == "self", ( + "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " + "code so that it calls into a method on ShapeEnv" + ) + nonlocal name + if name is None: + name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + assert isinstance(args[0], ShapeEnv) + + global NEST + + trace_shape_events_log.debug( + "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs + ) + NEST += 1 + + def retlog(r): + trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) + return r + + shape_env = args[0] + + try: + if not shape_env.should_record_events or shape_env.is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return retlog(fn(*args, **kwargs)) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return retlog(fn(*args, **kwargs)) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, + list(args), + kwargs, + tracked_fakes, + name=name, + ) + # Play the event on this ShapeEnv. + # NB: It's important to put the event first, because running + # the event can trigger internal events that must be ordered + # after this event. However, if an exception happens, we do + # NOT want to have the event in the list, so pop it off from + # the record if an error happened + self.events.append(event) + try: + return retlog(event.run(self)) + except Exception: + self.events.pop() + raise + + except Exception: + if not shape_env.should_record_events or shape_env.is_recording: + # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging. + raise + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) + raise + + finally: + NEST -= 1 + + return wrapper + + return decorator + + +# Replays the ShapeEnvEvents list. +# It assumes the first event is the constructor call. +# +# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. +def replay_shape_env_events(events): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + constructor_event = events[0] + assert constructor_event.f == ShapeEnv + + # Constructs the new ShapeEnv. + shape_env = constructor_event.run() + + for event in events[1:]: + try: + # Actually replays each event. + # We need to call create_mapping_fn every time, since the node list might + # change after each event is replayed. + event.run(shape_env) + except Exception: + log.error("failed when running event: %s", event) + raise + + return shape_env + + +# FakeTensor metadata. +# This is to be used in place of FakeTensor placeholders when calling +# ShapeEnv.produce_guards. +@dataclass +class FakeTensorMeta: + tensor_size: tuple[Union[int, torch.SymInt], ...] + tensor_stride: tuple[Union[int, torch.SymInt], ...] + tensor_storage_offset: Union[int, torch.SymInt] + is_nested: bool + + def size(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_size + + def stride(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_stride + + def storage_offset(self) -> Union[int, torch.SymInt]: + return self.tensor_storage_offset + + def dim(self) -> int: + return len(self.tensor_size) + + @staticmethod + def from_fake(fake) -> "FakeTensorMeta": + return FakeTensorMeta( + fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested + ) + + +# [Note: ShapeEnv State Equality] +# =============================== +# +# What is considered ShapeEnv state? +# ---------------------------------- +# We consider to be the state of a ShapeEnv instance everything that +# is not in the inline tuple inside remove_nonstate_variables function. +# That is: the fields within ShapeEnv that modify the flow of execution +# of the program. +# +# So, for example: the replacements field might influence on how an +# expression is simplified. That, in turn, may result in a guard being +# statically known (i.e. not added). +# +# On the other hand, var_to_stack serves only changes what is printed +# in the screen, i.e. used only for debugging purposes. Therefore, we +# should not consider it when comparing states. +# +# What to do on NotEqualError? +# ---------------------------- +# Here are a few possible causes for getting a NotEqualError raised: +# +# 1. New field that does not belong in the ShapeEnv state. +# For example: log field of type ShapeEnvLoggerAdapter. Different +# ShapeEnv instances will always have different ShapeEnvLoggerAdapter +# instances, i.e. equality comparison would fail. +# Solution: add it to the inlined tuple inside remove_nonstate_variables +# function inside check_equal method. +# +# 2. New field that is not directly comparable across instances. +# For example: guards field of type List[ShapeGuard]. More specifically, +# the ShapeGuard type holds an expression and a stack information +# for debugging purposes. When replaying the even on a new ShapeEnv +# instance, the stack would be different, which would trigger this error. +# Solution: add a special case to the map_value function inside +# check_equal function. +# +# 3. Mutation of ShapeEnv on some not recorded function. +# If a mutation of the state of ShapeEnv happens inside a function +# that is not recorded (or that no caller in the stack is recorded), +# then, the replayed ShapeEnv won't catch that. +# Solution: decorate the function with record_shape_env_event. + + +# Checks whether the state of two ShapeEnv are equal w.r.t. the guards +# returned by ShapeEnv.produce_guards. +def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): + # Collect and remove variables that don't necessarily represent the state + # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the + # instance itself. + env1_vars = vars(env1).copy() + env2_vars = vars(env2).copy() + + for v in non_state_variable_names: + if v in env1_vars: + env1_vars.pop(v) + if v in env2_vars: + env2_vars.pop(v) + + # Function for transforming the mismatched values into string. + # Needed, since dict and set entries order might not be the same every time. + def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return ( + "{" + + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str)) + + "}" + ) + if isinstance(value, set): + return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}" + return str(value) + + # Compares env1_vars with env2_vars. + # Here, we allow the value of each field to be mapped, so that we appropriately + # compare the two values. + def compare_vars( + map_value: Callable[[str, Any], Any], + ) -> list[tuple[str, str, str]]: + env1_set, env2_set = set(env1_vars), set(env2_vars) + + # First, compare the set of keys in each vars dictionary. + if env1_set != env2_set: + raise NotEqualError( + "field set mismatch:", + [ + ( + "found unique fields:", + str(sorted(env1_set - env2_set)), + str(sorted(env2_set - env1_set)), + ), + ], + ) + + # Then, sort the keys, and compare the mapped values of each key. + sorted_keys = list(env1_set) + sorted_keys.sort() + + mapped_dict = [ + (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k])) + for k in sorted_keys + ] + + # Return a list of tuples representing the fields that did not match + # alongside their respective mapped values. + return [ + (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2)) + for k, val1, val2 in mapped_dict + if val1 != val2 + ] + + # Accumulate the mismatching fields. + errors = compare_vars(map_value) + + if len(errors) > 0: + raise NotEqualError("field values don't match:", errors) + + +class NotEqualError(Exception): + def __init__( + self, + msg: str, + mismatched: list[tuple[str, str, str]], + ) -> None: + details = "\n".join( + [ + "\n".join( + [ + f"==> {inner_msg}", + f" > Left: {str1}", + f" > Right: {str2}", + ] + ) + for inner_msg, str1, str2 in mismatched + ] + ) + + super().__init__( + f"""\ +ShapeEnv not equal: {msg} + +{details} +""" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..8e92163a2139caab2fd2a690d810f52073e75644 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py @@ -0,0 +1,16 @@ +class Equality: + def __init__(self, lhs: object, rhs: object): + self.lhs = lhs + self.rhs = rhs + + def __str__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __repr__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __eq__(self, other: object) -> bool: + if isinstance(other, Equality): + return self.lhs == other.lhs and self.rhs == other.rhs + else: + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc902599aeb0a36d8253b0cf8cbece3f6e5ac68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import ast +import copy +import functools +import inspect +import textwrap +from collections.abc import Callable +from types import FunctionType +from typing import Any, cast, Optional, Union + +import torch +from torch._sources import normalize_source_lines +from torch.fx._symbolic_trace import Tracer +from torch.fx.graph import Graph + + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + # This function checks for new keys added in the globals dict. TorchDynamo + # can insert new keys in the global dict and upset the check. Therefore, put + # a disable here. This function is an optimization pass and not really + # suitable for dynamo tracing anyways. + @torch._dynamo.disable + def rewrite(self, fn: FunctionType): + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] + return g + + # Return the correct FunctionType object + return change_func_globals(fn_compiled, globals=fn.__globals__) + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse("torch._assert()", mode="eval") + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_AnnAssign(self, node): + """ + Swap out Python's AnnAssign with an Assign node where the annotation function is called. + Example: + Original: + y: Tensor_Type(1,2,3, Dyn) = f2(x) + Output: + y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) + """ + return ast.Assign( + targets=[node.target], + value=ast.Call( + func=ast.Name(id="annotate", ctx=ast.Load()), + args=[node.value, node.annotation], + keywords=[], + ), + ) + + +class RewritingTracer(Tracer): + def trace( + self, + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[dict[str, Any]] = None, + ) -> Graph: + return super().trace(_rewrite(root), concrete_args) + + +def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's `forward` as well as the `forward`s of + # all of this module's recursive descendents. Return the new, + # rewritten module hierarchy. + def rewrite_module(m: torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + for k, v in orig.__dict__.items(): + if isinstance(v, torch.nn.Module): + self.__dict__[k] = copy.copy(rewrite_module(v)) + else: + self.__dict__[k] = copy.copy(v) + + RewrittenModule.forward = AST_Rewriter().rewrite( + cast(FunctionType, m.forward) + ) + return RewrittenModule(m) + + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b2f1680d64a1ff928a8519dd4d93d61a861a54 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py @@ -0,0 +1,145 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Any, Optional + +import torch +import torch.fx +from torch._jit_internal import boolean_dispatched +from torch.fx import Transformer +from torch.fx.node import Argument, Target +from torch.fx.operator_schemas import _torchscript_type_to_python_type + + +class AnnotateTypesWithSchema(Transformer): + """ + Use Python function signatures to annotate types for `Nodes` within an FX graph. + This pulls out Python function signatures for: + + 1. Standard `torch.nn` Module calls + 2. `torch.nn.functional` calls + 3. Attribute fetches via `get_attr` + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = AnnotateTypesWithSchema(traced).transform() + + """ + + def __init__( + self, + module: torch.nn.Module, + annotate_functionals: bool = True, + annotate_modules: bool = True, + annotate_get_attrs: bool = True, + ): + super().__init__(module) + self.annotate_functionals = annotate_functionals + self.annotate_modules = annotate_modules + self.annotate_get_attrs = annotate_get_attrs + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + if self.annotate_functionals and target.__module__ == "torch.nn.functional": + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + # TODO: can we emit the union of these? What are the implications on TorchScript + # compilation? + if ( + inspect.signature(if_true).return_annotation + != inspect.signature(if_false).return_annotation + ): + return super().call_function(target, args, kwargs) + target_for_analysis = if_true + + python_ret_type = self._extract_python_return_type(target_for_analysis) + + return_proxy = super().call_function(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + assert isinstance(target, str) + submod = self.fetch_attr(target) + if self.annotate_modules and hasattr(submod.__class__, "__name__"): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + python_ret_type = self._extract_python_return_type(submod.forward) + return_proxy = super().call_module(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def get_attr( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + ): + attr_proxy = super().get_attr(target, args, kwargs) + + if self.annotate_get_attrs: + module_itr = self.module + assert isinstance(target, str) + atoms = target.split(".") + for i, atom in enumerate(atoms): + if not hasattr(module_itr, atom): + raise RuntimeError( + f"Node referenced nonextent target {'.'.join(atoms[:i])}!" + ) + module_itr = getattr(module_itr, atom) + + maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) + if maybe_inferred_ts_type.success(): + python_type = _torchscript_type_to_python_type( + maybe_inferred_ts_type.type() + ) + attr_proxy.node.type = ( + python_type if not attr_proxy.node.type else attr_proxy.node.type + ) + + return attr_proxy + + def _extract_python_return_type(self, target: Target) -> Optional[Any]: + """ + Given a Python call target, try to extract the Python return annotation + if it is available, otherwise return None + + Args: + + target (Callable): Python callable to get return annotation for + + Returns: + + Optional[Any]: Return annotation from the `target`, or None if it was + not available. + """ + assert callable(target) + try: + sig = inspect.signature(target) + except (ValueError, TypeError): + return None + + return ( + sig.return_annotation + if sig.return_annotation is not inspect.Signature.empty + else None + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000000000000000000000000000000..96b44b0aebd4d34eb9dc00fa3bfc0e133fe609bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py @@ -0,0 +1,1896 @@ +# mypy: allow-untyped-defs + +from __future__ import annotations + + +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + + +import builtins +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache, update_wrapper +from typing import Optional, TYPE_CHECKING, Union + +import torch +import torch._logging.structured as structured + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) +from torch._logging import dtrace_structured + + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) +sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") + + +__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"] + + +from torch.types import py_sym_types as SymTypes + + +def _to_symtype(t): + if t is bool: + return SymBool + if t is int: + return SymInt + if t is float: + return SymFloat + return t + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + # Note [optimized_summation]: indicates that SymNode is an Add expression of the form + # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations + # for common patterns see _optimized_add. + + # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression, + # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as + # a weak dictionary key either! So instead, we attach the attribute here to the SymNode. + _optimized_summation: bool = False + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float, bool]], + constant=None, + fx_node=None, + optimized_summation=False, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + self._optimized_summation = optimized_summation + + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) + # in hopes that we've learned enough about the unbacked symints to + # discharge the hint; otherwise, you're likely to just error out. + # + # (A previous version of this system had some optimizations to only + # recompute when it was possible we had learned enough about the + # unbacked symint that a hint was now possible, but as we added more + # potential refinements to unbacked symints this got harder to keep + # in sync, so we've deleted it for now.) + + def compute_hint(): + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + + # This occasionally gets exercised by, e.g., + # convert_shape_to_symint. It's just a nicety so you don't HAVE + # to have a correct hint on hand when making a SymNode. + # Don't attempt to compute for unbacked, this can be quite + # expensive. + if has_free_unbacked_symbols(self.expr): + return None + hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) + if hint is not None: + hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint + return hint + + if hint is not None: + assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( + "Cannot create SymNode of type " + f"{pytype} with incompatible hint of type {type(hint)}" + ) + if self.shape_env and self.shape_env._translation_validation_enabled: + # This is technically not TV, but this assert is expensive so + # let's only do it when we're already doing expensive things + computed_hint = compute_hint() + assert hint == computed_hint, ( + f"{hint} != {computed_hint} (for {self.expr})" + ) + else: + hint = compute_hint() + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. + tx_validation_en = ( + self.shape_env and self.shape_env._translation_validation_enabled + ) + self.fx_node = tx_validation_en and fx_node + + def with_shape_env(self, shape_env: ShapeEnv) -> SymNode: + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def _value_eq(self, other: SymNode) -> bool: + # Purposely don't include the shape_env in the eq. + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + + def _value_hash(self) -> int: + # Purposely don't include the shape_env in the hash. + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + @property + def hint(self): + return self._hint + + def has_hint(self): + return self._hint is not None + + def require_hint(self, fallback=None): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if self._hint is None: + if fallback is not None: + # Say we have some expr like 2*u0 + s0 + # The hint will be None, since the expr contains at least 1 unbacked. + # We will: + # - replace every backed free symbol with its corresponding hint + # - replace every unbacked free symbol with the fallback + # - regenerate the expression with those symbol replacements + # Note: this is not really complete either, since right now + # this logic does not take into account any value ranges + # for the unbacked symints, we may need to beef it up at some point. + unbacked_symbols = free_unbacked_symbols(self.expr) + replacements = { + s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s] + for s in self.expr.free_symbols + } + return self.expr.xreplace(replacements) + # NB: we expect this to raise + return self.shape_env.size_hint(self.expr) + return self._hint + + def maybe_as_int(self): + if self.expr.is_number: + return int(self.expr) + else: + return None + + # NB: This does conversions, not sure if this is good or not + def maybe_as_float(self): + import sympy + + if isinstance(self.expr, sympy.Float): + return float(self.expr) + else: + return None + + def maybe_as_bool(self): + import sympy + + if self.expr is sympy.true: + return True + elif self.expr is sympy.false: + return False + else: + return None + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def is_nested_int(self): + # Unbacked SymInts cannot be nested int today + return ( + self._hint is not None + and isinstance(self._hint, SymInt) + and self._hint.node.is_nested_int() + ) + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + rep = [ + f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", + ] + if self._hint is not None: + rep.append(f"hint={self._hint}") + if self.constant is not None: + rep.append(f"constant={self.constant}") + if self.fx_node is not None: + rep.append(f"fx_node={self.fx_node}") + return ", ".join(rep) + ")" + + def _graph_repr(self) -> builtins.str: + # Representation used by GraphModule to create a pythonic version of a graph + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> SymNode: + return self._abs() # type: ignore[attr-defined] + + def pos(self) -> SymNode: + return self._pos() # type: ignore[attr-defined] + + def round(self, ndigits=None) -> SymNode: + return self._round(ndigits) # type: ignore[attr-defined] + + def trunc(self) -> SymNode: + return self._trunc() # type: ignore[attr-defined] + + def add(self, other) -> SymNode: + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> SymNode: + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> SymNode: + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> SymNode: + return self._mod(other) # type: ignore[attr-defined] + + def float_pow(self, other) -> SymNode: + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> SymNode: + return self._pow_by_natural(other) # type: ignore[attr-defined] + + def and_(self, other) -> SymNode: + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> SymNode: + return self._or_(other) # type: ignore[attr-defined] + + def float_truediv(self, other) -> SymNode: + return self._float_truediv(other) # type: ignore[attr-defined] + + def int_truediv(self, other) -> SymNode: + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> SymNode: + return self._int_floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> SymNode: + return self._lshift(other) # type: ignore[attr-defined] + + def rshift(self, other) -> SymNode: + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> SymNode: # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> SymNode: + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> SymNode: + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> SymNode: + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> SymNode: + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> SymNode: + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> SymNode: + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> SymNode: + return self._floor() # type: ignore[attr-defined] + + def is_integer(self) -> SymNode: + return self._is_integer() # type: ignore[attr-defined] + + def sym_float(self) -> SymNode: # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> SymNode: + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> SymNode: + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> SymNode: + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> SymNode: # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> SymNode: # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> SymNode: + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> SymNode: + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode: + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + # Integer bitwise ops + def bitwise_and(self, other): + return self._bitwise_and(other) # type: ignore[attr-defined] + + def bitwise_or(self, other): + return self._bitwise_or(other) # type: ignore[attr-defined] + + def bitwise_xor(self, other): + return self._bitwise_xor(other) # type: ignore[attr-defined] + + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> SymNode: + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( + to_node(self, 1) + ) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # This one is currently done by hand, but if we add other variadic + # functions consider factoring it out to be metaprogrammed too. Note that + # some load bearing logic is directly in torch.sym_sum + + def sym_sum(self, args) -> SymNode: + import sympy + + # Inner impl + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + torch.sym_sum, + (tuple(wrap_node(a) for a in args),), + {}, + ), + ) + exprs = [a.expr for a in args] + out = sympy.Add(*exprs) + + size_hints = [] + out_hint = None + for a in args: + if a.hint is None: + break + size_hints.append(a.hint) + else: + out_hint = sum(size_hints) + + fx_node, _ = self.shape_env._create_fx_call_function( + torch.sym_sum, (tuple(a.fx_node for a in args),) + ) + + # NB: Only for integers! + return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node) + + def evaluate(self, size_oblivious=False): + return self.shape_env.evaluate_sym_node(self, size_oblivious) + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if ( + self.has_hint() + and not free_unbacked_symbols(self.expr) + and not self.shape_env.prefer_deferred_runtime_asserts_over_guards + ): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.guard_or_defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def statically_known_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import statically_known_true + + assert self.is_bool() + return statically_known_true(SymBool(self)) + + def guard_size_oblivious(self, file, line): + """ + Like guard_bool, but if we encounter unbacked symbols, if those symbols + are size-like, we will treat them as >= 2 for the purposes of the analysis. + + This CHANGES the runtime semantics, but all size-oblivious sites have been + audited to ensure that the runtime semantics don't change in a material way. + Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping + an unbacked one size, or a tensor reporting as non-contiguous even if it's + contiguous if it would have been reported contiguous due to being empty. + """ + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate(size_oblivious=True) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def guard_or_false(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + assert self.is_bool() + return guard_or_false(SymBool(self)) + + def guard_or_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + assert self.is_bool() + return guard_or_true(SymBool(self)) + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def nested_int(self): + return None + + def is_constant(self): + return False + + +class _DynamicScalar: + def __new__(cls, *args): + if cls is _DynamicScalar: + raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.") + return super().__new__(cls, *args) + + +class DynamicInt(_DynamicScalar, int): + """ + User API for marking dynamic integers in `torch.compile`. + Intended to be compatible with both compile and eager mode. + + Example usage:: + + fn = torch.compile(f) + x = DynamicInt(4) + fn(x) # compiles x as a dynamic integer input; returns f(4) + """ + + def __new__(cls, val): + assert isinstance(val, int) + obj = super().__new__(cls, int(val)) + return obj + + def __repr__(self): + return f"DynamicInt({self.real})" + + def __floordiv__(self, other): # // was casting to int without these overrides? + return DynamicInt(self.real // other) + + def __rfloordiv__(self, other): + return DynamicInt(other // self.real) + + +# TODO: this probably needs the sizes-strides eval functions +METHOD_TO_OPERATOR = { + "pos": operator.pos, + "abs": operator.abs, + "add": operator.add, + "and": operator.and_, + "bitwise_and": operator.and_, + "ceil": math.ceil, + "eq": operator.eq, + "floor": math.floor, + "trunc": math.trunc, + "int_floordiv": operator.floordiv, + "ge": operator.ge, + "gt": operator.gt, + "is_integer": lambda x: x.is_integer(), + "le": operator.le, + "lshift": operator.lshift, + "lt": operator.lt, + "mod": operator.mod, + "mul": operator.mul, + "ne": operator.ne, + "neg": operator.neg, + "or": operator.or_, + "bitwise_or": operator.or_, + "bitwise_xor": operator.xor, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, + "round": builtins.round, + "rshift": operator.rshift, + "sub": operator.sub, + "sym_float": sym_float, + "sym_ite": sym_ite, + "sym_max": sym_max, + "sym_min": sym_min, + "sym_not": sym_not, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, +} + +unary_magic_methods = { + "abs", + "sym_float", + "sym_int", + "ceil", + "floor", + "neg", + "sym_not", + "pos", + "trunc", +} + + +# Adding math ops: sqrt, cos, sin, ... +def _get_sym_node_fn(name): + def fn(self): + return getattr(self, f"_sym_{name}")() + + return fn + + +math_op_names = ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", + "log2", +) +for name in math_op_names: + sym_name = f"sym_{name}" + priv_sym_name = f"_{sym_name}" + setattr(SymNode, sym_name, _get_sym_node_fn(name)) + METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) + unary_magic_methods.add(sym_name) + __all__.append(sym_name) + + +# Unary methods that are not magic methods +unary_nonmagic_methods = { + "is_integer", +} + +unary_methods = unary_magic_methods | unary_nonmagic_methods + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that implicitly convert SymBool into SymInt +bool_becomes_int_magic_methods = {"add", "sub", "mul"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +# Methods that are only for float +only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"} + + +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} +# remap necessary because an op name can have a bitwise and boolean implementation +bitwise_ops = {"bitwise_and": "and", "bitwise_or": "or", "bitwise_xor": "xor"} + + +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} + +for name in math_op_names: + sym_name = f"sym_{name}" + always_float_magic_methods.add(sym_name) + + +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", + "is_integer", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv + + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + + +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + + return PowByNatural(a, b) + + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) + + +def _sympy_and(a, b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +def _binary_search_insert_arg(ordered_args, new_arg): + """ + If new_arg is found in ordered_args None is returned, else the new + ordered_args with new_arg inserted + """ + if len(ordered_args) == 0: + return [new_arg] + + from sympy.core.basic import _args_sortkey as sort_key, Basic + + # Fast path when new_arg > ordered_args[-1]. + if sort_key(ordered_args[-1]) < sort_key(new_arg): + return ordered_args + [new_arg] + + # Fast path when new_arg < ordered_args[0]. + if sort_key(ordered_args[0]) > sort_key(new_arg): + return [new_arg] + ordered_args + + low, high = 0, len(ordered_args) - 1 + + while low <= high: + mid = (low + high) // 2 + compare_result = Basic.compare(ordered_args[mid], new_arg) + if compare_result == 0: + return None + elif compare_result < 0: + low = mid + 1 + else: + high = mid - 1 + + ordered_args.insert(low, new_arg) + return ordered_args + + +def _optimized_add( + lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False +): + """ + Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea + is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols, + and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following. + 1. Avoid running other optimizations when the Add is constructed. + 2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n) + (comparing terms is expensive and shows in the profiles). + The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols, + (2) the result sympy expression. + """ + import sympy + from sympy.core.basic import _args_sortkey as sortkey + + def make_optimized(ordered_args): + assert ordered_args is not None + result = sympy.Add(*ordered_args, evaluate=False) + return (True, result) + + from torch.utils._sympy.functions import _is_symbols_binary_summation + + lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs) + rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs) + + if lhs_is_optimized_summation and rhs_is_optimized_summation: + # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3) + if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]): + return make_optimized(lhs._args + rhs._args) + # (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3) + if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]): + return make_optimized(rhs._args + lhs._args) + + # (a1+a3) + (a0+a2) => (a0+a1+a2+a3) + if len(lhs._args) <= 2 and len(rhs._args) <= 2: + new_args = list(lhs._args) + for a in rhs._args: + new_args = _binary_search_insert_arg(new_args, a) + if new_args is None: + break + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # (a0+a2) + a1 => (a0+a1+a2) + if lhs_is_optimized_summation and rhs.is_symbol: + new_args = _binary_search_insert_arg(list(lhs._args), rhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # a1 + (a0+a2)=> (a0+a1+a2) + if rhs_is_optimized_summation and lhs.is_symbol: + new_args = _binary_search_insert_arg(list(rhs._args), lhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + result = sympy.Add(lhs, rhs) + return (_is_symbols_binary_summation(result), result) + + +def _bitwise_and(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_and + + return BitwiseFn_bitwise_and(a, b) + + +def _bitwise_or(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_or + + return BitwiseFn_bitwise_or(a, b) + + +def _bitwise_xor(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_xor + + return BitwiseFn_bitwise_xor(a, b) + + +reflectable_magic_methods = { + "add": operator.add, + "sub": operator.sub, + "mul": operator.mul, + "mod": _sympy_mod, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, + "and": _sympy_and, + "bitwise_and": _bitwise_and, + "or": _sympy_or, + "bitwise_or": _bitwise_or, + "bitwise_xor": _bitwise_xor, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + from torch.utils._sympy.functions import FloorToInt + + return FloorToInt(a) + + +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) +def _sympy_trunc(a): + from torch.utils._sympy.functions import TruncToInt + + return TruncToInt(a) + + +def _sympy_ceil(a): + from torch.utils._sympy.functions import CeilToInt + + return CeilToInt(a) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + from torch.utils._sympy.functions import Min + + return Min(a, b) + + +def _sympy_max(a, b): + from torch.utils._sympy.functions import Max + + return Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +current_module = sys.modules[__name__] + + +def _get_sym_math_fn(name): + def fn(a): + import torch.utils._sympy.functions + + return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) + + return fn + + +for name in math_op_names: + priv_sympy_name = f"_sympy_{name}" + fn = _get_sym_math_fn(name) + fn.__qualname__ = fn.__name__ = priv_sympy_name + setattr(current_module, priv_sympy_name, fn) + +del fn, name, priv_sympy_name # type: ignore[possibly-undefined] + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +def _sympy_round(number, ndigits=None): + from torch.utils._sympy.functions import RoundDecimal, RoundToInt + + if ndigits is None: + return RoundToInt(number) + else: + return RoundDecimal(number, ndigits) + + +def _sympy_sym_float(a): + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) + + +def _sympy_is_integer(a): + import sympy + + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": operator.invert, + "pos": operator.pos, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "trunc": _sympy_trunc, + "sym_float": _sympy_sym_float, + "ceil": _sympy_ceil, + "neg": operator.neg, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "abs": _sympy_abs, + "round": _sympy_round, + "is_integer": _sympy_is_integer, +} + + +for name in math_op_names: + sym_name = f"sym_{name}" + magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") + +del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.S.One + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + from torch.utils._sympy.functions import Max + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.S.Zero + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + return METHOD_TO_OPERATOR[method] + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def uninteresting_files() -> set[str]: + import torch + + mods = [ + torch._dynamo.eval_frame, + torch._dynamo.utils, + torch.fx.experimental.sym_node, + torch, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + def capture_provenance(fn): + @functools.wraps(fn) + def wrapper(self, other=None): + if other is None: + result = fn(self) + else: + result = fn(self, other) + if torch._logging._internal.GET_DTRACE_STRUCTURED: + if other is not None: + arguments = [self, other] + else: + arguments = [self] + + def get_id(sym_node) -> Optional[int]: + # We don't want to return an ID if the input is a constant + import sympy + + if sym_node.constant is not None: + return None + elif id(sym_node) == id(result): + return None + elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)): + return None + elif sym_node.expr in (sympy.true, sympy.false): + return None + return id(sym_node) + + dtrace_structured( + "expression_created", + metadata_fn=lambda: { + "method": method, + "result": str(result), + "result_id": id(result), + "arguments": [str(a) for a in arguments], + "argument_ids": [ + get_id(i) for i in arguments if get_id(i) is not None + ], + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + }, + ) + + return result + + return wrapper + + @capture_provenance + def binary_magic_impl(self, other): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + optimized_summation = False + try: + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + elif method == "add": + # see Note [optimized_summation] + (optimized_summation, out) = _optimized_add( + self.expr, + other.expr, + self._optimized_summation, + other._optimized_summation, + ) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) + pytype: type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + if ( + pytype is not None + and out_hint is not None + and not isinstance(out_hint, SymTypes) + ): + out_hint = pytype(out_hint) + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env._create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + + result = SymNode( + out, + self.shape_env, + pytype, + out_hint, # type: ignore[arg-type] + fx_node=fx_node, + optimized_summation=optimized_summation, # see Note [optimized_summation] + ) + return result + + @capture_provenance + def unary_magic_impl(self): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + if get_proxy_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + sym_node_log.debug("%s %s -> %s", func, expr, out) + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + pytype: type + if method in always_int_magic_methods: + pytype = int + elif method in always_bool_magic_methods: + pytype = bool + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if get_proxy_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + fx_node, _ = pred_node.shape_env._create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + elif method == "round": + + def round_impl(self, ndigits=None): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = builtins.round + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) + ) + + expr = self.expr + try: + out = func(expr, ndigits) + except Exception: + log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) + raise + + if ndigits is None: + pytype = int + else: + pytype = self.pytype + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint, ndigits) + + # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the + # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here + # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The + # hack down below works, because all round function down the line all take ndigits=None as default in their + # signature. + # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL + args = [self.fx_node] + if ndigits is not None: + args.append(ndigits) + fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + setattr(SymNode, f"_{method_attr}", round_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = getattr(sys.modules[__name__], method) + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"sym_{method}" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymInt): + return x.node.guard_int("", 0) + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + + if method in bool_becomes_int_magic_methods: + + def promote(x): + """Implements True+True=2, which works in python but not sympy""" + if isinstance(x, SymBool): + return SymInt(x.node.wrap_int(int(x))) + return x + + else: + + def promote(x): + return x + + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + self = promote(self) + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + sym_node_log.debug("MAGIC %s %s %s", method, self, other) + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(other, get_constant(self)) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + def setattrs(user_type, attr, symnode_impl): + """ + Registers the SymNode magic method on SymInt/Float/Bool, + and optionally registers a corresponding wrapped method on DynamicInt. + """ + + # SymInt/Float/Bool + setattr(user_type, attr, symnode_impl) + + # DynamicInt impl + def dynamic_int_impl(*args): + args = [x.real if isinstance(x, DynamicInt) else x for x in args] + out = getattr(int, attr)(*args) + if isinstance(out, int) and not isinstance(out, bool): + return DynamicInt(out) + return out + + if user_type is SymInt: + setattr(DynamicInt, attr, dynamic_int_impl) + + if method in unary_magic_methods: + setattrs(user_type, f"__{method}__", unary_magic_impl) + elif method in unary_nonmagic_methods: + orig = getattr(user_type, method) + setattrs(user_type, method, update_wrapper(unary_magic_impl, orig)) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattrs(user_type, f"__{method}__", sym_ite_magic_impl) + elif method == "round": + + def round_magic_impl(self, ndigits=None): + if is_constant(self): + return builtins.round(get_constant(self), ndigits) + + return wrap_node(getattr(self.node, method)(ndigits)) + + setattrs(user_type, f"__{method}__", round_magic_impl) + else: + method_name = method + if method in bitwise_ops: + method_name = bitwise_ops[method] + setattrs(user_type, f"__{method_name}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) + + +for method in magic_methods: # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, SymBool) + continue + if method in only_float_magic_methods: + _make_user_magic(method, SymFloat) + continue + if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + if method not in bitwise_ops: + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..56ffc77c23b08e0c35860783658c2c84f3ce0397 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,8121 @@ +from __future__ import annotations + +import sympy +from sympy import S + +from torch._prims_common import BoolLike, FloatLike, IntLike + + +""" +``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with +our symbolic shapes reasoning system that is used heavily in torch.compile. Although +this is not generally considered public API, when writing framework code in PyTorch +as well as extensions to PyTorch (e.g., in custom operator implementations), you may +need to make use of these APIs to setup dynamic shapes support appropriately. +""" + +import abc +import atexit +import collections +import dis +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import re +import sys +import threading +import traceback +from collections import Counter, defaultdict +from collections.abc import Callable, Generator, Iterator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import ( + Any, + cast, + Generic, + NamedTuple, + NoReturn, + Optional, + TYPE_CHECKING, + TypeAlias, + TypeGuard, + TypeVar, + Union, +) +from typing_extensions import deprecated, ParamSpec + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import SymBool, SymFloat, SymInt +from torch._C._functorch import get_unwrapped, is_batchedtensor +from torch._guards import ShapeGuard, SLoc, Source, TracingContext +from torch._logging import dtrace_structured, LazyString, structured, trace_structured +from torch._subclasses.meta_utils import is_sparse_any +from torch._utils_internal import signpost_event +from torch.fx.experimental import _config as config +from torch.fx.experimental.recording import ( + FakeTensorMeta, + record_shapeenv_event, + replay_shape_env_events, + shape_env_check_state_equal, + ShapeEnvEvent, +) +from torch.fx.experimental.sym_node import SymNode, SymTypes +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.functions import ( + Application, + CeilToInt, + CleanDiv, + FloorDiv, + FloorToInt, + IntTrueDiv, + IsNonOverlappingAndDenseIndicator, + Max, + Mod, + PythonMod, + TruncToInt, +) +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import CppPrinter, PythonPrinter +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRangeError, + ValueRanges, +) +from torch.utils._traceback import CapturedTraceback, format_frame + + +if TYPE_CHECKING: + import types + + from torch import Tensor + from torch._dynamo.source import TensorPropertySource + from torch._subclasses.fake_tensor import FakeTensor + from torch.types import BoolLikeType, FloatLikeType, IntLikeType + + +InputList = list +DimList = list + +log = logging.getLogger(__name__) + + +class GuardOnDataDependentSymNode(RuntimeError): + cond: sympy.Basic + + def __init__(self, cond: sympy.Basic, *args: Any) -> None: + super().__init__(*args) + self.cond = cond + + +class PendingUnbackedSymbolNotFound(RuntimeError): + pass + + +aten = torch._ops.ops.aten # type: ignore[has-type] + +__all__ = [ + "size_hint", + "guard_or_false", + "guard_or_true", + "has_symbolic_sizes_strides", + "create_contiguous", + "ShapeEnv", + "is_concrete_int", + "is_concrete_float", + "is_concrete_bool", + "has_static_value", + "guard_int", + "guard_float", + "guard_scalar", + "canonicalize_bool_expr", + "hint_int", + "SYMPY_INTERP", + "free_symbols", + "is_symbol_binding_fx_node", + "is_nested_int", + "SHAPEENV_EVENT_KEY", + "CURRENT_NODE_KEY", + "has_free_symbols", + "has_free_unbacked_symbols", + "sym_and", + "sym_eq", + "sym_or", + "SymbolicContext", + "StatelessSymbolicContext", + "StatefulSymbolicContext", + "SubclassSymbolicContext", + "SymIntSymbolicContext", + "TrackedFake", + "statically_known_true", + "statically_known_false", + "guard_size_oblivious", + "check_consistent", + "compute_unbacked_bindings", + "ConvertIntKey", + "rebind_unbacked", + "resolve_unbacked_bindings", + "is_accessor_node", + "ValueRangesSLoc", + "SymIntEqByExpr", + "Specialization", +] + +# FX node metadata keys for symbolic shape FX graph. +SHAPEENV_EVENT_KEY = "shapeenv_event" +CURRENT_NODE_KEY = "current_node" + + +def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: + log.debug( + "lru_cache_stats %s: %s", + wrapped_f.__name__, # type: ignore[attr-defined] + wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined] + ) + + +# Note about Sympy Expr/SympyBoolean/Basic typing: the Sympy hierarchy is +# +# Basic +# Expr +# SympyBoolean +# Relational +# +# Notably, Expr and SympyBoolean are not related. So use Basic when the +# expression could denote int, float OR bool, and otherwise use the more +# specific Expr for int/float and SympyBoolean for bool. +# +# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. +# So make sure only type checker evaluates this alias. +# Xref: https://www.internalfb.com/diff/D53324783 +SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" + + +_T = TypeVar("_T") +_SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic) + + +class SymIntEqByExpr: + """ + This is a wrapper around SymInt which has alternative semantics for + equality and pickling. Specifically, instead of erroring or guarding, we + instead will hash/compare equality based on the underlying sympy + expression; e.g., s0 and s1 will always compare as False. + + NB: This does NOT do fancy analysis that maybe_evaluate_static does; + we can only reason through equalities that occur because to expressions + canonicalize to the same expression via regular simplification. + """ + + @staticmethod + def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr: + if isinstance(val, torch.SymInt): + return val.node.expr + else: + return sympy.Integer(val) + + def __init__(self, val: Union[torch.SymInt, int]) -> None: + self.val: sympy.Expr = SymIntEqByExpr._extract(val) + + def __repr__(self) -> str: + return repr(self.val) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SymIntEqByExpr) + return self.val == other.val + + def __hash__(self) -> int: + return hash(self.val) + + +def _nested_int_aware_sort( + tup: tuple[IntLikeType, int], +) -> tuple[int, IntLikeType, int]: + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) + if is_nested_int(tup[0]) + else (0, *tup) + ) + + +def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None: + """Gets a size hint for a given expression from the underlying shapes we had. + Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + if isinstance(x, int): + return x + assert isinstance(x, torch.SymInt) + return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none) + + +# Wrapper on lru_cache that reports statistics at process end +def lru_cache( + maxsize: Optional[int], +) -> Callable[[Callable[..., _T]], functools._lru_cache_wrapper[_T]]: + def inner(f: Callable[..., _T]) -> functools._lru_cache_wrapper[_T]: + wrapped_f = functools.lru_cache(maxsize)(f) + old_cache_clear = wrapped_f.cache_clear + prev_hits = 0 + prev_misses = 0 + + # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info + # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not + # weakref'able on some versions of Python + + def cumulative_cache_info() -> functools._CacheInfo: + cur = wrapped_f.cache_info() + return functools._CacheInfo( + prev_hits + cur.hits, + prev_misses + cur.misses, + cur.maxsize, + cur.currsize, + ) + + def new_cache_clear() -> None: + nonlocal prev_hits, prev_misses + cur = wrapped_f.cache_info() + prev_hits += cur.hits + prev_misses += cur.misses + old_cache_clear() + + wrapped_f.cache_clear = new_cache_clear # type: ignore[attr-defined, method-assign] + wrapped_f.cumulative_cache_info = cumulative_cache_info # type: ignore[attr-defined, method-assign] + if log.isEnabledFor(logging.DEBUG): + atexit.register(log_lru_cache_stats, wrapped_f) # type: ignore[arg-type] + return wrapped_f + + return inner + + +# These are modules that contain generic code for interacting with ShapeEnv +# which are unlikely to identify a particular interesting guard statement +@lru_cache(None) +def uninteresting_files() -> set[str]: + import torch._compile + import torch._dynamo.eval_frame + import torch._inductor.sizevars + import torch._library.custom_ops + import torch._library.fake_impl + import torch._logging + import torch._subclasses.fake_tensor + import torch._subclasses.meta_utils + import torch.export._trace + + mods = [ + sys.modules[__name__], + torch.export._trace, + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, + torch.fx._symbolic_trace, + torch, + torch._compile, + torch._dynamo.eval_frame, + torch._inductor.sizevars, + torch._library.custom_ops, + torch._library.fake_impl, + torch._subclasses.meta_utils, + torch._subclasses.fake_tensor, + torch._logging._internal, + torch._logging.structured, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + +class ConstraintViolationError(RuntimeError): + pass + + +def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool: + return elem._has_symbolic_sizes_strides + + +Int: TypeAlias = Union[torch.SymInt, int] + + +def create_contiguous(shape: Sequence[Int]) -> list[Int]: + strides: list[Int] = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) # type: ignore[operator] + return list(reversed(strides)) + + +def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: + """ + Retrieve the hint for an int (based on the underlying real values as observed + at runtime). If no hint is available (e.g., because data dependent shapes), + if fallback is not None, use that instead (otherwise raise an error). + """ + if isinstance(a, torch.SymInt): + return a.node.require_hint(fallback) + assert type(a) is int, a + return a + + +Scalar: TypeAlias = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] + + +def has_hint(a: Scalar) -> bool: + if isinstance(a, SymTypes): + return a.node.has_hint() + return True + + +def is_concrete_int(a: IntLikeType) -> bool: + """ + Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or int): Object to test if it int + """ + assert isinstance(a, (SymInt, int)) + + if isinstance(a, int): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Integer): + return True + + return False + + +def is_concrete_float(a: FloatLikeType) -> bool: + r"""Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or float): Object to test if it float + """ + assert isinstance(a, (SymFloat, float)) + + if isinstance(a, float): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Float): + return True + + return False + + +def is_concrete_bool(a: BoolLikeType) -> bool: + """ + Utility to check if underlying object + in SymBool is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymBool or bool): Object to test if it bool + """ + assert isinstance(a, (SymBool, bool)) + + if isinstance(a, bool): + return True + + if isinstance( + a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse) + ): + return True + + return False + + +def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> bool: + """ + User-code friendly utility to check if a value is static or dynamic. + Returns true if given a constant, or a symbolic expression with a fixed value. + + Args: + a (Union[SymBool, SymFloat, SymInt, bool, float, int]): Object to test + """ + assert isinstance(a, BoolLike + FloatLike + IntLike) + if ( + isinstance(a, BoolLike) + and is_concrete_bool(a) # type: ignore[arg-type] + or isinstance(a, FloatLike) + and is_concrete_float(a) # type: ignore[arg-type] + or isinstance(a, IntLike) + and is_concrete_int(a) # type: ignore[arg-type] + ): + return True + + assert isinstance(a, py_sym_types) + return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr] + + +def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: + """ + Perform a guard on a symbolic boolean expression in a size oblivious way. + This is typically used when a non-oblivious test would result in a guard + on a data dependent value of which we don't know the value of at compile time. + When a guard is tested this way, we may diverge in behavior from how regular + PyTorch semantics would treat it. For more information, see + https://github.com/pytorch/pytorch/pull/118579 + """ + if isinstance(expr, torch.SymBool): + return expr.node.guard_size_oblivious("", 0) + else: + assert isinstance(expr, bool), expr + return expr + + +def check_consistent(new: _T, old: _T) -> None: + """ + Test that two "meta" values (typically either Tensor or SymInt) have + the same values, e.g., after retracing. If we don't understand the + quantities in question, we'll just skip the consistency check. + """ + # TODO: do boolean equality test too, see + # https://github.com/pytorch/pytorch/issues/124110 + scalar_types = (torch.SymInt, torch.SymFloat, int, float) + + if isinstance(new, torch.Tensor): + assert isinstance(old, torch.Tensor) + torch._check( + old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)" + ) + # Do this manually so that each individual test is irrefutable + # (TODO: should be a helper for this, maybe sym_eq? That + # gives us a compound expression and I'm not sure it + # simplifies right now) + for i, j in zip(old.shape, new.shape): + torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") + # NB: bool is subclass of int + elif isinstance(new, scalar_types) and not isinstance(new, bool): + assert isinstance(old, scalar_types) and not isinstance(old, bool), ( + f"{old} != {new}" + ) + torch._check(old == new, lambda: f"{old} != {new} (old != new)") + + +def resolve_unbacked_bindings( + shape_env: Optional[ShapeEnv], + bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + When we do fake tensor prop, we oftentimes will allocate new unbacked symints. + We then run proxy tensor mode, which populates node.meta["unbacked_bindings"] + with these new symints. To ensure consistency we use PropagateUnbackedSymInts + to rename unbacked bindings to their old ones. But all of the node metas are + still using the old bindings from before the renaming. This function helps to + post facto apply any renamings discovered in the PropogateUnbackedSymInts pass. + """ + if bindings is None: + return None + assert shape_env is not None + return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} + + +Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]] + + +def rebind_unbacked( + shape_env: Optional[ShapeEnv], n: torch.fx.Node, result: Result +) -> None: + """ + Suppose we are retracing a pre-existing FX graph that previously had + fake tensor propagation (and therefore unbacked SymInts). When we retrace, + we re-propagate fake tensors, which results in new unbacked SymInts. + When this happens, we need to tell the shape environment about the equivalence + of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which + has the old binding information) and the new result (which we can extract the + new unbacked SymInts out from). + """ + + # Inputs never need rebinding + if n.op == "placeholder": + return + + if bindings := resolve_unbacked_bindings( + shape_env, n.meta.get("unbacked_bindings") + ): + assert shape_env is not None + for raw_u0, path in bindings.items(): + u1 = pytree.key_get(result, path) + + # Sometimes, things were previously unbacked bindings become constants. + # There are two situations this can happen. + # + # First, you might have a runtime assert that causes the + # constant-ification. In this case, the /binding/ itself will + # still be an unbacked symbol (because we will only force it + # to be a constant later in fake tensor propagation). In this + # case, u1 is a SymInt and we still do all our work as normal. + # + # But second, it might be that fake tensor propagation DIRECTLY + # converted the unbacked SymInt into a constant. This happens + # more rarely, but we have identified two situations it can + # validly occur: + # + # - If you have a tensor_version operator, these are initially + # allocated as unbacked SymInts, but after AOTAutograd they + # get forced specialized to specific values. In this case, + # there is no reason to do runtime asserts on them, this is + # just a hack to properly keep track of them to start. + # + # - If you have an item() call on a constant tensor, the result + # of the item() call is constant and we do not need runtime + # asserts on this symbol. In + # https://github.com/pytorch/pytorch/issues/140625 we have a + # case where in the initial trace of the program we are unable + # to determine that torch.tensor is constant, but then + # subsequent passes cause torch.tensor to become a constant and + # then the unbacked symbol goes poof. + # + # In all of these cases, it is no longer necessary to generate + # deferred runtime asserts, since other subsystems (e.g., the + # constant-ification pass) ensure that the quantity is now truly + # static and cannot change at runtime. So it's OK to discard + # in these situations. + # + # There is one more hazard (re + # https://github.com/pytorch/pytorch/issues/141248), the problem + # is that you can end up with "dangling" unbacked symbols that + # exist in the ShapeEnv but are never bound anywhere. You might + # like an invariant that unbacked symbols never get lost. But + # we do not have this invariant, so do not try to enforce it. + if isinstance(u1, (int, float)): + log.info( + "rebind_unbacked: discard %s %s %s -> %s", + n.target, + raw_u0, + path, + u1, + ) + continue + + # We only care about rebinding unbacked things + if u1.node.hint is not None: + continue + + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? + # Simplify SymBool binding + if ( + isinstance(raw_u1, sympy.Piecewise) + and len(raw_u1.args) == 2 + and ( + raw_u1_args0 := cast( + tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] + ) + ) + and raw_u1_args0[0] == 1 + and isinstance(eq := raw_u1_args0[1], sympy.Eq) + and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) + and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) + and eq.rhs == 1 + and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) + ): + # This is what the pattern match above is testing + repacked = _sympy_cast_symbool_to_symint_guardless( + sympy.Eq(new_raw_u1, 1) + ) + assert repacked == raw_u1, f"{repacked} != {raw_u1}" + # Cancel the to_int(to_bool(x)). This is sound because x in + # [0, 1] + + raw_u1 = new_raw_u1 + + if not isinstance(raw_u1, sympy.Symbol): + assert not raw_u1.free_symbols, ( + f"should have been constant, but got {raw_u1}" + ) + continue + + # The old and new could be the same if you improperly hit the memo + # while retracing. Make sure you updated FakeTensorMode.epoch + assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster" + # Reuse the OLD symbol name + shape_env._rename_unbacked_to(raw_u1, raw_u0) + + +# NB: You could try to expand this to cover more cases by simply +# detecting whenever you have an int output, but this is a bit +# dangerous in case someone adds a function that returns an int but is +# mutating. So manually whitelist for now. +def is_accessor_node(node: torch.fx.Node) -> bool: + """ + Helper function to determine if a node is trying to access + a symbolic integer such as size, stride, offset or item. Currently + primarily only used in a DCE pass to figure out purity. + """ + + # Dynamo only exercised condition + if ( + node.op == "call_method" + and isinstance(node.args[0], torch.fx.Node) + and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) + and node.target in ["size", "stride", "storage_offset", "item"] + ): + return True + + if node.op == "call_function" and node.target in [ + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.default, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_storage_offset, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.sym_numel.default, + ]: + return True + + return False + + +def canonicalize_bool_expr(expr: _T) -> _T: + """ + Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance( + expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne) + ): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) # type: ignore[arg-type, return-value] + + +def _sympy_from_args( + cls: type[Union[sympy.Add, sympy.Mul]], + args: list[sympy.Expr], + sort: bool = True, + is_commutative: Optional[bool] = None, +) -> sympy.Expr: + """ + Create a sympy expression from a list of arguments, optimizing for performance. + + This function creates a sympy Add or Mul expression from a list of arguments + while avoiding expensive operations like flattening. It handles sorting the + arguments appropriately based on the expression type. + + Args: + cls: The sympy class to create (Add or Mul) + args: List of sympy expressions to combine + sort: Whether to sort the arguments (default: True) + is_commutative: Whether the operation is commutative (default: None) + + Returns: + A sympy expression of type cls combining all arguments + + Raises: + ValueError: If cls is not sympy.Add or sympy.Mul + """ + + if not args: + return cls.identity # type: ignore[union-attr] + + # These args are already in canonical form, so we avoid calling + # Add(*args) to avoid expensive Add.flatten operation + if sort: + if cls is sympy.Add: + sort_fn = sympy.core.add._addsort + elif cls is sympy.Mul: + sort_fn = sympy.core.mul._mulsort + else: + raise ValueError(f"Unknown cls: {cls}") + + # we don't support non commutative with sort + assert is_commutative is True + if args[0].is_Number: + rest = args[1:] + sort_fn(rest) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + args = args.copy() + sort_fn(args) + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + # if the args are already sorted, we create directly + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + + +def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: + """ + After canonicalization, we are guaranteed to have eliminated Ge/Gt relations + (rewriting them to Le/Lt, respectively). + """ + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + t: Union[type[Any]] + if isinstance(expr, tuple(opposite.keys())): + rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] + t = opposite[type(expr)] # type: ignore[index] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + rhs = expr.rhs - expr.lhs + t = type(expr) + + def is_neg(t: sympy.Expr) -> bool: + return (t.is_Number and t.is_negative) or ( + isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative + ) + + lhs = S.Zero + rhs = _reduce_to_lowest_terms(rhs) + if isinstance(rhs, sympy.Add): + pos = [] + neg = [] + for term in rhs.args: + if is_neg(term): + neg.append(-term) + else: + pos.append(term) + # these are already sorted + rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True) + # the terms were changed, so needs a sorting + lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) + elif is_neg(rhs): + # lhs == 0 + lhs, rhs = -rhs, S.Zero + # We don't have to evaluate here because lhs, rhs came from a Boolean + # and it was already simplified + return t(lhs, rhs, evaluate=False) + + +def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: + """ + Eliminates any integer factor from a given expression. + E.g., 6x + 4y reduces to 3x + 2y. + + Useful when an expression is == or != to 0. + """ + + def integer_coefficient(x: sympy.Expr) -> int: + if x.is_Integer: + return abs(int(x)) + elif x.is_Mul: + # If one of the args of a Mul is an Integer, it is the + # first arg. eg: args(2*x*3*y) == (6, x, y) + return abs(int(x.args[0])) if x.args[0].is_Integer else 1 # type: ignore[call-overload] + else: + return 1 + + def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr: + if x.is_Integer: + return x / factor + elif x.is_Mul: + if x.args[0] != factor: + args = [x.args[0] / sympy.Integer(factor), *x.args[1:]] + else: + # Mul._from_args require a canonical list of args + # so we remove the first arg (x.args[0] / factor) if it was 1 + args = list(x.args[1:]) + return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative) + else: + raise AssertionError(f"illegal arg to div_by_factor: {x}") + + if expr.is_Add: + atoms = cast(Sequence[sympy.Expr], expr.args) + factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) + if factor == 1: + return expr + # pyrefly: ignore [bad-argument-type] + atoms = [div_by_factor(x, factor) for x in atoms] + return _sympy_from_args( + sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative + ) + elif expr.is_Integer: + return S.One + elif expr.is_Mul: + return div_by_factor(expr, integer_coefficient(expr)) + return expr + + +def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]: + return isinstance(s, torch.SymInt) and s.node.is_nested_int() + + +IterateExprsAtom: TypeAlias = Union[ + SymInt, SymFloat, SymBool, int, float, bool, sympy.Basic, torch.Tensor +] +IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]] + + +def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: + """ + Recursively iterate through a value and yield all sympy expressions contained within it. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + any symbolic expressions they contain. It's used for operations like finding free symbols + in complex nested structures. + + Args: + val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a sympy expression, a primitive type (int, float, bool), a container (tuple, list), + a sparse tensor, a regular tensor, None, or a torch.Generator. + + Yields: + sympy.Basic: Each sympy expression found in the value. + + Raises: + AssertionError: If the value is of an unsupported type. + """ + # This is almost close enough to implement in terms of _iterate_nodes() + # except that it needs to handle `list[sympy.Basic]` which _iterate_nodes() + # can't handle. + if isinstance(val, SymTypes): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node.expr + elif isinstance(val, SymNode): + yield val.expr + elif isinstance(val, sympy.Basic): + yield val + elif isinstance(val, (int, float, bool)): + pass + elif isinstance(val, (tuple, list)): + for s in val: + yield from _iterate_exprs(s) + elif is_sparse_any(val): + yield from _iterate_exprs(val.size()) + elif isinstance(val, torch.Tensor): + yield from _iterate_exprs(val.size()) + yield from _iterate_exprs(val.stride()) + yield from _iterate_exprs(val.storage_offset()) + elif val is None: + pass + # see Note: [Generator arguments in AOTDispatcher] + elif isinstance(val, torch.Generator): + pass + else: + raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") + + +def _iterate_nodes(val: Any) -> Iterator[SymNode]: + """ + Recursively iterate through a value and yield all SymNodes contained + within it. + """ + if isinstance(val, SymNode): + yield val + elif isinstance(val, py_sym_types): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node + elif isinstance(val, (tuple, list, torch.Size)): + for s in val: + yield from _iterate_nodes(s) + elif isinstance(val, torch.Tensor): + yield from _iterate_nodes(val.size()) + if not is_sparse_any(val): + yield from _iterate_nodes(val.stride()) + yield from _iterate_nodes(val.storage_offset()) + + +def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]: + """ + Recursively collect all free symbols from a value. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + all sympy symbols contained within them. It's useful for finding all symbolic variables + that a complex nested structure depends on. + + Args: + val: The value to extract symbols from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a container (tuple, list), a tensor, or None. + + Returns: + OrderedSet[sympy.Symbol]: An ordered set of all free symbols found in the value. + """ + if val is None: + return OrderedSet() + + itr = _iterate_exprs(val) + + # we need at least 1 to call union, so we hand code the identity + try: + first_expr = next(itr) + except StopIteration: + return OrderedSet() + + # TODO: Apparently, returning an OrderedSet here breaks + # python test/distributed/tensor/test_dtensor_compile.py TestDTensorCompile.test_dtensor_dynamic + return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) # type: ignore[return-value] + + +def has_free_symbols(val: IterateExprs) -> bool: + """Faster version of bool(free_symbols(val))""" + return not all((e.is_number or e.is_Boolean) for e in _iterate_exprs(val)) + + +def has_free_unbacked_symbols(x: IterateExprs) -> bool: + """Faster version of bool(free_unbacked_symbols(val))""" + from sympy.core.traversal import iterargs + + for s in _iterate_exprs(x): + for arg in iterargs(s): + if arg.is_Symbol and symbol_is_type( + arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + return True + return False + + +def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]: + """Like free_symbols, but filtered to only report unbacked symbols""" + + # NB: keep synced with is_unbacked_symint + return OrderedSet( + s + for s in free_symbols(x) + if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)) + ) + + +def _free_non_source_unbacked_symbols( + x: IterateExprs, unbacked_inputs: OrderedSet[sympy.Symbol] +) -> OrderedSet[sympy.Symbol]: + """Unbacked symbols that are not inputs to the graph. These are symbols that originated from + data-dependent operations as opposed to mark_unbacked calls.""" + unbacked_symbols = free_unbacked_symbols(x) + non_source_symbols = unbacked_symbols - unbacked_inputs + return non_source_symbols + + +# WARNING: Don't use this on Dynamo produced graphs, they don't have meta +# setup! +def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]: + """ + Check if a given FX node is a symbol binding node. + + A symbol binding node is one that has a SymInt value in its meta that contains + a sympy Symbol expression, and is either a placeholder node or contains unbacked symbols. + + Args: + node (torch.fx.Node): The FX node to check + + Returns: + Optional[sympy.Symbol]: The sympy Symbol if the node is a symbol binding node, None otherwise + """ + if ( + "val" in node.meta + and isinstance(node.meta["val"], torch.SymInt) + and isinstance(node.meta["val"].node.expr, sympy.Symbol) + and ( + node.op == "placeholder" + or free_unbacked_symbols(node.meta["val"].node.expr) + ) + ): + return node.meta["val"].node.expr + return None + + +def find_symbol_binding_fx_nodes( + graph: torch.fx.Graph, +) -> dict[sympy.Symbol, torch.fx.Node]: + """ + Find all nodes in an FX graph that bind sympy Symbols. + + This function scans through all nodes in the given FX graph and identifies + nodes that bind sympy Symbols (typically placeholder nodes with SymInt values). + When multiple nodes bind the same symbol, only the first occurrence is kept. + + Args: + graph: The FX graph to search for symbol binding nodes + + Returns: + A dictionary mapping from sympy Symbols to their binding FX nodes + """ + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if (s := is_symbol_binding_fx_node(node)) is not None and s not in r: + r[s] = node + return r + + +@dataclass(frozen=True) +class Specialization: + """ + This class is used in multi-graph compilation contexts where we generate + multiple specialized graphs and dispatch to the appropriate one at runtime. + This allows us to optimize the trade-off between performance and generality + by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0) + while maintaining a general fallback. + """ + + source: TensorPropertySource + check_fn: Callable + + +# Analogous to ConvertIntSource +@dataclass(frozen=True) +class ConvertIntKey: + def __str__(self) -> str: + return ".cast_symbool_to_symint_guardless()" + + def get(self, b: bool) -> IntLikeType: + """Get the int value from bool""" + return cast_symbool_to_symint_guardless(b) + + +@dataclass(frozen=True) +class CallMethodKey: + name: str + + def __str__(self) -> str: + return f".{self.name}()" + + def get(self, o: Any) -> Any: + """Call the method on object""" + return getattr(o, self.name)() + + +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + +@dataclass(frozen=True) +class DivideByKey: + divisor: IntLikeType + + def __str__(self) -> str: + return f".__floordiv__({self.divisor})" + + def get(self, o: int) -> int: + """Divide object by divisor""" + return o // self.divisor + + +def _free_unbacked_symbols_with_path( + a: object, + path: pytree.KeyPath, + real: Optional[object] = None, + shape_env: Optional[ShapeEnv] = None, + pending: Optional[set[sympy.Symbol]] = None, + simplify: bool = False, +) -> dict[sympy.Symbol, pytree.KeyPath]: + """ + Recursively traverses a structure to find unbacked symbols and their access paths. + + This function walks through tensors, lists, tuples, and symbolic values to locate + unbacked symbols that are in the pending set, and returns a mapping from those + symbols to their access paths in the structure. + + Args: + a: The object to traverse (tensor, list, tuple, SymInt, etc.) + path: The current path in the object tree + real: Optional real tensor corresponding to the fake tensor being traversed + shape_env: Optional ShapeEnv to register unbacked values with + pending: Set of unbacked symbols to look for (will be modified in-place) + simplify: Whether to use simplified expressions + + Returns: + A dictionary mapping unbacked symbols to their access paths + """ + go = functools.partial( + _free_unbacked_symbols_with_path, + shape_env=shape_env, + pending=pending, + simplify=simplify, + ) + + def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: + if simplify: + return s.node.expr + # (When called from compute_unbacked_bindings) + # NB: Intentionally access _expr, not expr, do not want + # simplification! + return s.node._expr + + if pending is None: + pending = set() + r = {} + + def match_tensor(a: torch.Tensor, real_tensor: Optional[torch.Tensor] = None): + r.update( + go( + a.size(), + path + (CallMethodKey("size"),), + real=real_tensor.size() if real_tensor is not None else None, + ) + ) + if a.layout not in [ + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + ]: + r.update( + go( + a.stride(), + path + (CallMethodKey("stride"),), + real=real_tensor.stride() if real_tensor is not None else None, + ) + ) + r.update( + go( + a.storage_offset(), + path + (CallMethodKey("storage_offset"),), + real=( + real_tensor.storage_offset() if real_tensor is not None else None + ), + ) + ) + + if isinstance(a, (tuple, list)): + # NB: real is apparently not always a tuple/list here + # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu + for i in range(len(a)): + r.update( + go( + a[i], + path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None, # type: ignore[index] + ) + ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update(go(sub, path + (InnerTensorKey(attr),))) + + # match DTensor outer shapes + if torch.distributed.is_available() and isinstance( + a, torch.distributed.tensor.DTensor + ): + match_tensor(a) + elif isinstance(a, torch.Tensor) and is_batchedtensor(a): + unwrapped_tensor = get_unwrapped(a) + r.update(go(unwrapped_tensor, path)) + elif isinstance(a, torch.Tensor) and not is_batchedtensor(a): + from torch._subclasses.fake_tensor import FakeTensor + + assert isinstance(a, FakeTensor) + match_tensor(a, a.real_tensor) + elif ( + isinstance(a, (torch.SymInt, torch.SymFloat)) + and isinstance(s := expr(a), sympy.Symbol) + and s in pending + ): + r[s] = path + if shape_env and real is not None: + assert isinstance(real, (int, float)) + + shape_env.set_unbacked_var_to_val(s, real) + + pending.remove(s) + # When an unbacked SymInt is perfectly divisible by an integer + # constant, we replace it with the integer constant to improve + # reasoning capabilities. However, in synthetic examples, it is + # then possible that the factor never is explicitly allocated. + # Fortunately, we can compute it by division. + elif ( + isinstance(a, torch.SymInt) + and isinstance(s := expr(a), sympy.Mul) + and len(s.args) == 2 + and isinstance(lhs := s.args[0], (sympy.Integer, sympy.Symbol)) + and isinstance(rhs := s.args[1], sympy.Symbol) + # support exactly one unbacked for now + and ((rhs in pending) ^ (lhs in pending)) + # support constant coefficient or backed symbolic coefficient + and ( + isinstance(coeff := lhs if lhs not in pending else rhs, sympy.Integer) + or shape_env + and coeff in shape_env.var_to_val + ) + ): + + def _symint_wrap(s: sympy.Symbol) -> SymInt: + return shape_env.create_symintnode( # type: ignore[union-attr] + s, + hint=int(shape_env.var_to_val[s]), # type: ignore[union-attr] + source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr] + ) + + unbacked = lhs if lhs in pending else rhs + divisor: IntLikeType = ( + int(coeff) + if shape_env and isinstance(coeff, sympy.Integer) + else _symint_wrap(coeff) + ) + # TODO: DivideByKey needs to test divisibility at runtime! + + r[unbacked] = path + (DivideByKey(divisor),) + if real is not None: + assert isinstance(real, int) + val = ( + real // int(coeff) + if isinstance(coeff, sympy.Integer) + else CleanDiv(real, coeff) + ) + if shape_env: + shape_env.set_unbacked_var_to_val(unbacked, val) + pending.remove(unbacked) + # The annoyance here arises from the fact that SymBool is + # allocated by allocating a SymInt and then testing if it's equal + # to one. So you have a complicated binding site logic for this. + elif ( + isinstance(a, torch.SymBool) + and isinstance(s := expr(a), sympy.Eq) + # This must match create_unbacked_symbool EXACTLY + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + and s.lhs in pending + ): + r[s.lhs] = path + (ConvertIntKey(),) + if real is not None: + assert type(real) is bool + if shape_env: + shape_env.set_unbacked_var_to_val(s, int(real)) + + pending.remove(s.lhs) + + return r + + +def compute_unbacked_bindings( + shape_env: Optional[ShapeEnv], + example_value: object, + old_example_value: Optional[object] = None, + peek: bool = False, +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + After having run fake tensor propagation and producing example_value + result, traverse example_value looking for freshly bound unbacked + symbols and record their paths for later. It is an error if + we have allocated an unbacked SymInt but it cannot be found in + example_value. (NB: this means if you have a multi-output + function, you must call this on the tuple of tensor output, you + cannot wait!) + + The peek parameter lets you check out what the bindings are without + changing the affected list. This is primarily useful for ensuring + unbacked_var_to_val is promptly populated when propagate_real_tensors is on. + """ + if shape_env is None: + return None + + fs = shape_env.pending_fresh_unbacked_symbols + + pending = set(fs) + if not pending: + return None + + if not peek: + log.info("compute_unbacked_bindings %s", fs) + fs.clear() + + symbol_to_path = _free_unbacked_symbols_with_path( + example_value, (), shape_env=shape_env, pending=pending, simplify=False + ) + if not peek and pending: + extra = ( + repr((example_value.stride(), example_value.storage_offset())) + if isinstance(example_value, torch.Tensor) + else "" + ) + raise PendingUnbackedSymbolNotFound( + f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n" + "Did you accidentally call new_dynamic_size() or item() more times " + "than you needed to in your fake implementation?\n" + "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit" + ) + + # Why do we have to do some rebinding here? If the original FX node + # wasn't a binding site because you had a memo hit, but post + # translation you aren't a memo hit anymore, there's now a new binding + # site... but we know (because it's the same FX node) that the value + # is actually the same, they're just not obviously equal anymore. + # + # The logic here is written carefully, because unlike the + # bind_unbacked case, we are not guaranteed to have a symbol for + # old_sym. If we have a symbol, do regular rename unbacked to; but if + # we don't, we need to specially eliminate the fresh unbacked symbol + # (NB: we are /trusting/ that the memoization is correct, and that we + # don't need to generate a new runtime assert. This is load bearing, + # as repropagation can happen after we've frozen runtime asserts.) + if old_example_value is not None: + for keypath in symbol_to_path.values(): + old_sym = pytree.key_get(old_example_value, keypath) + new_sym = pytree.key_get(example_value, keypath) + if isinstance(new_sym, SymTypes) and isinstance( + new_s := new_sym.node.expr, sympy.Symbol + ): + if ( + isinstance(old_sym, SymTypes) + and (old_s := old_sym.node.expr) != new_s + ): + # If old_s is not an unbacked_symbol, + # we assume that the original unbacked symbol is replaced + # by a backed symbol (old_s). This can happen + # when this node reuses the original symbol (due to memoi) + # and the original symbol gets replaced by the backed symbol. + # When this happens we just replace new_s by the old_s + # because we know the value is the same. + + if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s): + shape_env._rename_unbacked_to(new_s, old_s) + else: + shape_env._eliminate_unbacked(new_s, old_s) + elif not isinstance(old_sym, SymTypes): + shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) + + return symbol_to_path + + +# Note [guard_or_] +# The following two functions are common utilities used while defining unbacked semantics +# of various framework code. Those would be used in situations you prefer to guard and know +# the result of the expression over not guarding, but in case you hit a data dependent error +# you are ok with just returning true or false. +# +# When to use this? +# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting). +# +# (2) It can be used if the program would behave equivalently if _guard_or returned true or false. +# Many inductor optimizations fall in this bracket for example. +# +# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the +# change is semantics preserving. It can be semantics preserving if the program errors in more +# cases than it did previously (but otherwise behaves identically), or if it changes some quantity +# in a way that doesn't matter (e.g., strides often fall in this bucket.) +# +# (4) Specialize for the general case and add a runtime assertion that would fail during +# runtime if the conditions for the general case are not satisfied. Examples for this are; +# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path. +# +def _guard_or(a: BoolLikeType, default: bool) -> bool: + """ + Try to guard a, if data dependent error encountered just return default. + """ + if not isinstance(a, SymBool): + assert isinstance(a, bool) + return a + + # if backed_size_oblivious is True we treat backed as unbacked here. + if torch.fx.experimental._config.backed_size_oblivious: + result = _static_eval_sym_bool(a) + return result if result is not None else default + + shape_env = getattr(a.node, "shape_env", None) + + # xla symnode path. + if shape_env is None: + return guard_bool(a) + + sym_node = a.node + r = sym_node.shape_env.evaluate_sym_node( + sym_node, size_oblivious=False, fallback_value=default + ) + return bool(r) + + +def guard_or_false(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return false. + """ + return _guard_or(a, False) + + +def guard_or_true(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return true. + """ + return _guard_or(a, True) + + +def _static_eval_sym_bool(x: SymBool) -> Optional[bool]: + assert isinstance(x, SymBool) + expr = x.node.expr + + try: + # Shape env access is inside the try on purpose. xla symnode does not + # have it on its attributes. + shape_env = x.node.shape_env + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + else: + return None + except Exception: + log.debug("Could not simplify %s", expr) + return None + + +def statically_known_false(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is False. + If x cannot be evaluated from static, we return False + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to False at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return not x + + result = _static_eval_sym_bool(x) + if result is None: + return False + + return not result + + +def statically_known_true(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is true. + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to true at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return x + result = _static_eval_sym_bool(x) + if result is None: + return False + + return result + + +def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + and, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.and_(x, y) + return x + + +def sym_eq(x: _T, y: _T) -> BoolLikeType: + """ + Like ==, but when run on list/tuple, it will recursively test equality + and use sym_and to join the results together, without guarding. + """ + if isinstance(x, (tuple, list)) and isinstance(y, (list, tuple)): + if len(x) != len(y): + return False + return functools.reduce(operator.and_, map(sym_eq, x, y), True) + elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): + return x == y + else: + raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") + + +def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + or, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.or_(x, y) + return x + + +def guard_scalar( + a: Union[SymBool, SymInt, SymFloat, int, bool, float], +) -> Union[bool, int, float]: + """ + Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float. + + This function dispatches to the appropriate guard function based on the type of the input. + + Args: + a: A symbolic or concrete scalar value (bool, int, or float) + + Returns: + The concrete value after guarding + + Raises: + AssertionError: If the input is not a recognized scalar type + """ + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + + +def _advise_is_size(a: SymInt) -> None: + """ + Don't use this directly; use torch._check_is_size instead. + + This is a softer version of _constrain_range_for_size (with min=0, + max=Inf). Instead of forcibly constraining a variable (and erroring if we + failed to constrain it), it will simply advise us that a size is + constrained in some way. We will always defer a runtime assert for this + constraint if we cannot prove it at compile-time, but we we only + *sometimes* learn useful extra information at compile-time with this + information. This is in contrast to constrain_range_for_size, where if + you don't call that on a fresh unbacked symint, chances are we will choke. + + TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed + code. Right now this is only really used in code with AOTAutograd trace + through, so it is not a big problem that this isn't supported, but in + principle all of this code should be Dynamo'able too. + + TODO: I didn't support min/max because I didn't have a use case where this + actually helped. In principle we can support it, it just makes the + implementation below more complicated. + """ + + # This must always succeed, because the sole allowed caller _check_is_size + # was responsible for expect_true'ing this + # This assert triggers expensive sym compute, do not do it until its cheap. + # assert a >= 0 + + # NB: it's important not to constrain range for size for *hinted* SymInts, + # because it is not only unsound, it will immediately trip our asserts + # that hints have to be consistent with static analysis! If you somehow + # have an unbounded SymInt that later constrains to 1, this will be + # inconsistent with the range + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + ): + _constrain_range_for_size(a) + + +def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None: + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + and isinstance(upper_bound, int) # TODO: relax + ): + a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound) + + +def _constrain_range_for_size( + a: SymInt, min: Optional[int] = None, max: Optional[int] = None +) -> None: + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt), "can only constrain range for SymInt" + assert isinstance(a.node.expr, sympy.Symbol), f"constraining non-Symbols NYI: {a}" + + a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) + + +# inclusive both ways +def constrain_range( + a: SymInt, *, min: Optional[int], max: Optional[int] = None +) -> None: + """ + Applies a constraint that the passed in SymInt must lie between min-max + inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning + that it can be used on unbacked SymInts). If min/max are None, we assume + that the dimension is unbounded in that direction. Repeated application + of constrain_range intersects the ranges. This is a fairly low level API + that doesn't have a lot of safety guarantees (TODO: provide higher level + APIs). + + Currently, we use this API in the following circumstance: when we allocate + an unbacked SymInt, denoting an integer quantity which is data dependent, + we ordinarily do not know anything about what values it may take. This + means that any sort of guard on it will immediately fail. However, in + many cases, we know something about the unbacked SymInt: for example, we + know that nonzero(x).size(0) must be >= 0. We use constrain_range to + narrow the possible range, declaring that negative symbols are impossible. + This permits to definitely answer True to queries like 'nnz >= 0', even if + we don't know what the actual (hinted) value of 'nnz' is. In fact, we + actually use constrain_range to unsoundly discharge common guards: for an + unbacked SymInt produced by nonzero, we will also assume that it is not + equal to 0/1 (even though these are perfectly possible values at runtime), + because we generally expect graphs that are valid for N=2 to also be valid + for N=1. + """ + if min is None: + min = -int_oo + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + f"received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") + return + + a.node.shape_env._constrain_range(a.node.expr, min, max) + + +def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + return + else: + shape_env = b.node.shape_env + else: + shape_env = a.node.shape_env + + shape_env._constrain_unify(a, b) + + +# Assume that a boolean is true for the purposes of subsequent symbolic +# reasoning. This will keep track of corresponding runtime checks to verify +# that the result is upheld: either as a regular guard, or as a special set +# of asserts which are triggered when an unbacked SymInt is allocated. +# +# DO NOT use this function for these cases: +# +# - This is inappropriate for "branching" conditions (where both +# true and false result in valid programs). We will always assume +# the condition evaluates true, and so it will never be possible +# to trace the false condition when you use it. For true branching +# on unbacked SymInts, you must use torch.cond; if you incorrectly +# use expect_true in this case, you will make the false branch +# unreachable (as we will simply assume that only the true branch +# is ever exercised). +# +# - This is inappropriate for situations where you know some other system +# invariant guarantees that this property holds, since you don't +# really need to insert a runtime check in that case. Use something +# like constrain_range in that case. +# +# This API has a hitch. To avoid having to reimplement error reporting +# capabilities, this function CAN return False. The invariant is that +# the surrounding code must raise an error when this function returns +# False. This is quite low level, so we recommend using other functions +# like check() which enforce this in a more intuitive way. +# +# By the way, this name is a nod to the __builtin_expect macro, +# which is used similarly (but unlike __builtin_expect, you MUST fail +# in the unlikely branch.) (I think expect is a good name; in recent +# versions of C++, this is replaced with [[likely]], which is weaker +# and not accurate for this function!) +def expect_true(a: BoolLikeType, skip: int = 0) -> bool: + if isinstance(a, SymBool): + # TODO: check perf implications of this + frame = inspect.currentframe() + for _ in range(skip + 1): # always run this loop at least once + if frame is None: + break + frame = frame.f_back + return a.node.expect_true( + frame.f_code.co_filename if frame else "", frame.f_lineno if frame else 0 + ) + assert type(a) is bool, a + return a + + +def guard_bool(a: BoolLikeType) -> bool: + if isinstance(a, SymBool): + return a.node.guard_bool("", 0) # NB: uses Python backtrace + assert type(a) is bool, a + return a + + +def guard_int(a: IntLikeType) -> int: + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert type(a) is int, a + return a + + +def guard_float(a: FloatLikeType) -> float: + if isinstance(a, SymFloat): + return a.node.guard_float("", 0) # NB: uses Python backtrace + assert isinstance(a, float), a + return a + + +# Given a GraphModule, return all the FakeTensors for all the placeholders +def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]: + return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"] + + +def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]: + return [n.target for n in gm.graph.nodes if n.op == "placeholder"] + + +# Given a GraphModule and arguments to run it with, evaluate that the guards +# for its associated ShapeEnv are satisfied by the passed arguments. This +# WILL check for duck sizing. +def eval_guards( + gm: torch.fx.GraphModule, *args: Tensor, ignore_static: bool = True +) -> bool: + assert gm.shape_env is not None + return gm.shape_env.evaluate_guards_for_args( # type: ignore[operator, union-attr] + fx_placeholder_vals(gm), args, ignore_static=ignore_static + ) + + +def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]: + assert gm.shape_env is not None + return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr] + + +class DimDynamic(Enum): + """ + Controls how to perform symbol allocation for a dimension. It is always + sound to default this to DYNAMIC, but the policies DUCK and STATIC can + result in better trace-time and compile-time performance, as they reduce + the number of allocated symbols and generally make your graph more static. + + NB: If we notice you've applied a constraint to the dimension, we will + force it to DYNAMIC for simplicity. + + DimDynamic is controlled by a variety of higher level UX features. + Currently: + + - In eager mode, the default policy is DUCK. + - The default is changed to STATIC with assume_static_by_default. + - An individual dim is marked DYNAMIC if you mark_dynamic_dim. + - In export mode, the default policy is STATIC. + - An individual dim is marked DYNAMIC if you specify it in + dynamic_shapes passed to export. + """ + + # Treat the dimension symbolically + DYNAMIC = 0 + # Treat the dimension symbolically, but if its hint matches another + # dynamic dimension, unify the two symbols ("duck sizing") + DUCK = 1 + # Treat the dimension statically based on its hint + STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 + # Infer the strides from stride. If size is static, strides will be static as well. + INFER_STRIDE = 4 + # Like SIZE_LIKE_UNBACKED, but there's a hint + OBLIVIOUS_SIZE = 5 + + +# NB: These constraints affect both clients and backends: given some +# constraint C, the client must pass inputs that satisfy the constraint, +# while a backend must not introduce guards BEYOND this constraint. +# For clarity, we document the implications on both sides for both the client +# and the backend. +# +# NB: These constraints are on a *single* dimension. In principle, we could +# also have multi-dimension constraints, but our guess is that this is not +# actually useful and so we are not supporting it right now. +# +# NB: Strict constraints are typically only suitable for export, as in eager +# a backend like inductor may validly introduce extra, discretionary guards +# to improve performance of code. A StrictMinMaxConstraint would be brittle +# under future optimizations performed by inductor; we don't guarantee +# eager code with StrictMinMaxConstraint will keep working in the future! + + +@dataclass(frozen=True) +class Constraint: + warn_only: bool + + +@dataclass(frozen=True) +class StrictMinMaxConstraint(Constraint): + """ + For clients: the size at this dimension must be within 'vr' (which + specifies a lower and upper bound, inclusive-inclusive) AND it + must be non-negative and should not be 0 or 1 (but see NB below). + + For backends: there must not be any guards on this dimension which + are not implied by the given lower and upper bound. Regardless of + the lower bound, the backend can assume the size is non-negative + and that it is not 0 or 1. + + An unbounded StrictMinMaxConstraint can be thought of as a strict version + of "RelaxedUnspecConstraint". + + NB: Export will often unsoundly assume that a graph works for 0/1, even + though at trace time we assumed size is not 0 or 1. The idea is that + if we produce a graph that works for a range of values, it will be OK + for N=0/1 too. + """ + + vr: ValueRanges + + def render(self, source: Source) -> str: + """Format the constrain equation""" + # TODO: better printing for -oo and oo + return f"{self.vr.lower} <= {source.name} <= {self.vr.upper}" + + +@dataclass(frozen=True) +class RelaxedUnspecConstraint(Constraint): + """ + For clients: no explicit constraint; constraint is whatever is implicitly + inferred by guards from tracing. + + For backends: there must exist at least TWO possible values for the + size at this dimension which satisfy the guards for this dimension. + + In other words, this constraint helps us distinguish between "we don't + care if this dimension specializes or not" versus "this dimension must be + unspecialized." However, this constraint doesn't say very much about what + specialization is permitted; for example, if we guard on a size being + even, this would still be acceptable under an unspec constraint. This + makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler + may add constraints to otherwise dynamic dimensions; we can't assert that + there are NO guards as this is brittle because compilers should be able to + add extra constraints. If you want to assert that there are no guards, + use StrictMinMaxConstraint with an unbounded ValueRanges. + """ + + def render(self, source: Source) -> str: + return f"RelaxedUnspecConstraint({source.name})" + + +# NB: None here indicates the client constraint is whatever is implicitly +# inferred by guards from tracing, and that a backend can add whatever guards +# it wants (including fully specializing the value). +DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] + + +@dataclass(frozen=True) +class EqualityConstraint(Constraint): + """ + Represent and decide various kinds of equality constraints between input sources. + + A "source pair" is a pair of input sources for dynamic dimensions that + are specified equal. We represent `source_pairs` in a union-find forest + so that we can efficiently check whether two such sources are transitively equal. + + A "derived equality" relates an input source to an expression over a root. + The root can be another input source, corresponding to some dynamic dimension, + or a phantom symbol that does not directly represent any dynamic dimension. We + represent `derived_equalities` involving input sources in a transitively-closed map + so that we can efficiently check whether an input source is transitively equal to + a given expression over another input source. + (NOTE: In contrast, it is easy to decide whether an input source is transitively equal + to a given expression over a phantom symbol; such expressions are already in canonical + form and so the problem reduces to symbolic expression equality.) + """ + + source_pairs: list[tuple[Source, Source]] + derived_equalities: list[ + tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] + ] + phantom_symbols: list[sympy.Symbol] + relaxed_sources: set[Source] + + _parents: dict[Source, Source] = field(init=False) + _defs: dict[Source, sympy.Expr] = field(init=False) + + def __post_init__(self) -> None: + """ + Pre-processing to answer queries `is_equal` and `is_derived` below. + + Example: Suppose we are given: + source_pairs [a = b, b = c] + derived_equalities [d = c + 1, e = d - 1] + We first construct a union find with source_pairs: + _parents = {a: a, b: a, c: a} + Then we compute canonical symbolic expressions, recursively applying derived_equalities + until we bottom out: + _defs = {d: c + 1, e: (c + 1) - 1 aka c} + """ + + # self._parents is a map from input sources to input sources where, conceptually, + # these are directed edges in a union-find forest + _parents: dict[Source, Source] = {} + object.__setattr__(self, "_parents", _parents) + # self._defs is a map from input sources to "canonical" symbolic expressions, + # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., + # not derived Dims) + _defs: dict[Source, sympy.Expr] = {} + object.__setattr__(self, "_defs", _defs) + + for source1, source2 in self.source_pairs: + # preprocess into a union-find forest + self._union(self._find(source1), self._find(source2)) + for source, root, fn in self.derived_equalities: + # preprocess into a transitively-closed map + # NOTE(avik): we reuse the union-find forest for canonicalizing input sources + if isinstance(root, (sympy.Symbol, sympy.Integer)): + self._defs[self._find(source)] = fn(root) + else: + self._defs[self._find(source)] = fn(self._rewrite(root)) + + def _find(self, source: Source) -> Source: + # chase edges to find the root of this equivalence class + if source in self._parents: + return self._find(self._parents[source]) + else: + return source + + def _union(self, root1: Source, root2: Source) -> None: + # merge two equivalence classes by adding an edge from one root to the other + if root1 != root2: + self._parents[root1] = root2 + + def _rewrite(self, src: Source) -> sympy.Expr: + # always represent the given source by the root of its equivalence class + src = self._find(src) + if src in self._defs: + # simply look up the definition if it exists + # NOTE(avik): This works because definitions are always transitively-closed; + # otherwise we would have to do recursive rewriting. + return self._defs[src] + else: + # otherwise, create a symbol representing the source + return sympy.Symbol(src.name) + + def is_equal(self, source1: Source, source2: Source) -> bool: + return ( + # check whether source1 and source2 have the same root + # or are relaxed + (src1 := self._find(source1)) in self.relaxed_sources + or (src2 := self._find(source2)) in self.relaxed_sources + or src1 == src2 + # check whether source1 is derived equal to source2 + or self.is_derived(source1, source2, lambda x: x) + ) + + def is_derived( + self, src: Source, symbol_src: Source, fn: Callable[[sympy.Expr], sympy.Expr] + ) -> bool: + # check whether both src and symbol_src have the same definition + return self._rewrite(src) == fn(self._rewrite(symbol_src)) + + +def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: + assert isinstance(symbolic_context, SymbolicContext), ( + "Invalid symbolic_context object" + ) + assert type(symbolic_context) is not SymbolicContext, ( + "Illegal usage of symbolic_context ABC" + ) + return True + + +def _is_supported_equivalence(expr: sympy.Expr) -> bool: + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or ( + isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs) + ) + return isinstance(expr, sympy.Symbol) + + +def _has_uninterpretable_sympy_function(expr: sympy.Basic) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ + return expr.has( + torch.utils._sympy.functions.ToFloat, + torch.utils._sympy.functions.TruncToInt, + torch.utils._sympy.functions.CeilToInt, + ) + + +@dataclass(frozen=True) +class SymbolicContext: + """ + Data structure specifying how we should create symbols in + ``create_symbolic_sizes_strides_storage_offset``; e.g., should + they be static or dynamic. + + This is an abstract base class because we are probably going to add + another version of this that says "use exactly these SymInts, don't + allocate fresh symbols." + """ + + +@dataclass(frozen=True) +class SymIntSymbolicContext(SymbolicContext): + """ + Data structure specifying any constraints on a SymInt input + """ + + constraint: DimConstraint + + +_P1 = ParamSpec("_P1") +_T1 = TypeVar("_T1") + + +@dataclass(frozen=True) +class StatelessSymbolicContext(SymbolicContext, Generic[_P1, _T1]): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. + This will cause fresh symbols to be allocated + """ + + dynamic_sizes: DimList[DimDynamic] + dynamic_strides: DimList[DimDynamic] = None # type: ignore[assignment] + constraint_sizes: DimList[DimConstraint] = None # type: ignore[assignment] + constraint_strides: DimList[DimConstraint] = None # type: ignore[assignment] + specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None + # If the tensor is a view, this should be populated for the base. It contains + # information on how to allocate symbols when recursively fakeifying the base + # during view fake-ification. + view_base_context: Optional[SymbolicContext] = None + # TODO: add storage offset and stride symbolic_context + + def __post_init__(self) -> None: + if self.specialize_on is None: + object.__setattr__( + self, + "specialize_on", + [[]] * len(self.dynamic_sizes), + ) + if self.dynamic_strides is None: + object.__setattr__( + self, + "dynamic_strides", + [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes), + ) + if self.constraint_sizes is None: + object.__setattr__( + self, "constraint_sizes", [None] * len(self.dynamic_sizes) + ) + if self.constraint_strides is None: + object.__setattr__( + self, "constraint_strides", [None] * len(self.dynamic_sizes) + ) + assert all( + stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) + for stride in self.dynamic_strides + ) + + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owner's responsibility to maintain the lifecycle of the cache + with respect to different shape_envs, clearing, etc. + """ + + tensor_source: Source = None # type: ignore[assignment] + # Why is this keyed on int first? + # That integer is actually the id of the shape_env. This cache short-circuits symbol + # creation, and we must store it per shape env. Now, while tracing invariants are a single + # shape env per tracing context, and every new frame gets a new shape_env. So where would we have + # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events + # is invoked, and creates a new shape_env. Replaying events against this new shape_env will + # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never + # get recorded in var_to_val, etc. + # TODO(voz): consider a weakref to the shape_env here + shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.shape_env_to_source_to_symbol_cache: + object.__setattr__(self, "shape_env_to_source_to_symbol_cache", {}) + + +@dataclass(frozen=True) +class SubclassSymbolicContext(StatefulSymbolicContext): + """ + The correct symbolic context for a given inner tensor of a traceable tensor subclass + may differ from that of the outer symbolic context. This structure allows for this + flexibility, with inner symbolic contexts mapped via attr -> symbolic context. + """ + + inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + if self.inner_contexts is None: + # pyrefly: ignore [bad-assignment] + self.inner_contexts = {} + + +@dataclass +class TrackedFake: + """ + Tracks the sources of all fake tensors we wrap in Dynamo. + Used by shape guard computation. + """ + + fake: Union[FakeTensor, SymInt] + source: Source + symbolic_context: Optional[SymbolicContext] + + def __hash__(self) -> int: + return hash((self.fake, self.source.name)) + + def __eq__(self, other: object) -> bool: + if isinstance(other, TrackedFake): + return self.fake is other.fake and self.source.name == other.source.name + return False + + +def is_symbolic( + val: Union[int, SymInt, float, SymFloat, bool, SymBool], +) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]: + if isinstance(val, (int, float, bool)): + return False + return val.node.is_symbolic() + + +IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + + +def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]: + """ + Expand products of sums into sums of products. + + This function takes a list of sympy expressions and separates them into + additive expressions (those with is_Add=True) and other expressions. + It then computes the distributive product, expanding (a+b)*(c+d) into a*c + a*d + b*c + b*d. + + Args: + args: A list of sympy expressions to expand + + Returns: + A tuple containing: + - The expanded expression as a sympy.Expr + - A boolean indicating whether expansion occurred (True if multiple additive + expressions were present or if there was at least one additive and one other expression) + """ + adds, other = [], [] + for arg in args: + if arg.is_Add: + adds.append(arg) + else: + other.append(arg) + + result = [sympy.Mul(*other)] + for add in adds: + result = [a * b for a, b in itertools.product(result, add.args)] + + result = sympy.Add(*result) + return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0) + + +def _fast_expand(expr: _SympyT) -> _SympyT: + """ + A faster implementation of sympy's expand function for common cases. + + This function expands expressions like (a+b)^n or (a+b)*(c+d) into sums of products, + but avoids the expensive checks and features of sympy's full expand implementation. + It only recreates objects when necessary to avoid expensive operations. + + Args: + expr: A sympy expression to expand + + Returns: + The expanded expression + """ + + # The expand algorithm in sympy is slow due to all the features is supports + # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is + # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement + # such features here to avoid expensive checks. We also make sure that we + # only re-create the objects if any of the args changed to avoid expensive + # checks when re-creating objects. + new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] + # pyrefly: ignore [missing-attribute] + if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + # pyrefly: ignore [missing-attribute] + return _fast_expand(expr.func(*new_args)) + + # pyrefly: ignore [missing-attribute] + if expr.is_Pow: + base: sympy.Expr + exp: sympy.Expr + base, exp = expr.args # type: ignore[assignment] + if exp.is_Integer and base.is_Add: + if exp > 1: + return sympy.expand_multinomial(expr, deep=False) + elif exp < 0: + return S.One / sympy.expand_multinomial(S.One / expr, deep=False) + # pyrefly: ignore [missing-attribute] + elif expr.is_Mul: + num: list[sympy.Expr] = [] + den: list[sympy.Expr] = [] + # pyrefly: ignore [missing-attribute] + for arg in expr.args: + if arg.is_Pow and arg.args[1] == -1: + den.append(S.One / arg) # type: ignore[operator, arg-type] + else: + num.append(arg) # type: ignore[arg-type] + + num, num_changed = _expandsums(num) + den, den_changed = _expandsums(den) + if num_changed or den_changed: + return num / den + + return expr + + +@lru_cache(256) +def safe_expand(r: _SympyT) -> _SympyT: + """ + Expand the given symbolic expression by recursively rewriting product of + sums into sum of products (with the product being either a multiplication or + exponentiation). + + NOTE: using this on an intermediate expression may prevent simplification + down the line, e.g., if we eagerly expand `(a + b)^2` into `a^2 + 2ab + b^2`, + we won't be able to simplify `(a^2 + 2ab + b^2) / (a + b)` as easily. + """ + if hasattr(r, "expand"): + try: + return _fast_expand(r) + except RecursionError: + log.warning("RecursionError in _fast_expand(%s)", r) + return r + else: + return r + + +class _SymbolInfo(NamedTuple): + k: sympy.Symbol + vr: Optional[ValueRanges] + val: Optional[sympy.Integer] + is_size_like: bool + + +@lru_cache(None) +def _maybe_evaluate_static_worker( + expr: _SympyT, + # NB: this is a tuple to ensure it can be LRU cached + symbol_info: tuple[_SymbolInfo, ...], + unbacked_only: bool, + size_oblivious: bool, +) -> Optional[_SympyT]: + """ + This variant of ShapeEnv._maybe_evaluate_static has no dependence on + ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting + for static evaluation, including nontrivial reliance on Sympy simplification + that occurs when we reallocate the symbols + """ + + # Simplify making use of value range lower bound + new_shape_env = {} + new_range_env = {} + for idx, sinfo in enumerate(symbol_info): + k, vr, val, is_size_like = sinfo + if isinstance(val, SingletonInt): + # Skip var_ranges logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + assert vr is not None + if size_oblivious and is_size_like: + lower = max(2, vr.lower) + # Clamping size-oblivious to some quantity below sys.maxsize + # helps us determine that f(u0) != sys.maxsize, which is a + # test that is looking for sys.maxsize as a sentinel, but you + # don't really want to worry about it for unbacked SymInts. + # This is similar to the flavor where size oblivious omits + # 0/1, it changes semantics but in a benign way. + upper = min(2**48, vr.upper) + # Excluding the very upper bound can be helpful + if upper > lower: + upper = upper - 1 + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= upper: + vr = ValueRanges(lower, upper) + else: + lower = vr.lower + # Don't do anything if we don't have a nontrivial lower bound + # Also don't do anything if we asked only to simplify unbacked + # SymInt + if lower is -int_oo or (unbacked_only and val is not None) or not vr.is_int: + new_range_env[k] = vr + continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # + # Positive means >= 1 + # Positive - 1 means >= 0 + # Positive + lower - 1 means >= lower + # The new symbol 's' is "too low", so when we substitute it in + # we have to increase it by offset (and conversely, the new + # variables have to have their value range bounds adjusted as + # well) + s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True) + + # Note: + # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. + # Sympy might give unexpected results when comparing an integer with a non-integer + # Therefore, we cast offset to int here. + # For example: + # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) + # expr = sympy.Eq(shape_0 - 1/3, 4) + # expr.xreplace({}) # False + offset = int(lower - 1) + new_shape_env[k] = s + offset + new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + + # TODO: remove this try catch (esp for unbacked_only) + try: + # pyrefly: ignore [missing-attribute] + new_expr = expr.xreplace(new_shape_env) + except RecursionError: + log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) + return None + + # We need to canonicalize, as after expand we may have something like `a + b = a` and + # sympy will not simplify the a. The two appeareances of the a will then make value ranges + # analysis give lose bounds + new_expr = canonicalize_bool_expr(safe_expand(new_expr)) + if new_expr.is_number: + return new_expr + + # Check if the range can solve it statically + out = bound_sympy(new_expr, new_range_env) + if out.is_singleton(): + return out.lower + + return new_expr if unbacked_only else None + + +def error() -> NoReturn: + raise AssertionError("shouldn't be hit") + + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> int: + return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) + + +def _eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> bool: + """ + Evaluates whether a tensor with the given sizes and strides is non-overlapping and dense. + + A tensor is non-overlapping if there's no memory location that belongs to more than one element. + A tensor is dense if all elements are stored in memory without gaps. + + Args: + sizes: Sequence of dimension sizes for the tensor + strides: Sequence of strides for the tensor + + Returns: + True if the tensor is non-overlapping and dense, False otherwise + """ + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1)) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr: + return sympy.Piecewise((1, x), (0, True)) + + +def cast_symbool_to_symint_guardless( + symbool: Union[bool, torch.SymBool], +) -> Union[int, torch.SymInt]: + """ + Converts a SymBool or bool to a SymInt or int without introducing guards. + + This function maps True to 1 and False to 0, preserving the symbolic nature + of the input when it's a SymBool. Unlike regular casting which might introduce + guards, this function performs the conversion without adding any guards. + + Args: + symbool: A boolean value, either a concrete bool or symbolic SymBool + + Returns: + The corresponding integer value (1 for True, 0 for False) as either + a concrete int or symbolic SymInt + """ + if isinstance(symbool, bool): + return 1 if symbool else 0 + int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) + return symbool.node.shape_env.create_symintnode( + int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None + ) + + +SYMPY_INTERP = { + "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, + "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, + "math": math, + "torch": torch, +} + + +def _lru_cache( + fn: Callable[..., _T], maxsize: Optional[int] = None +) -> functools._lru_cache_wrapper[_T]: + """ + Wrapper around lru_cache that clears when new info about shapes has been + updated. + + Use lru_cache if the output is always the same, regardless of the + constraints we know now (i.e. evaluate_expr) + + Use _lru_cache otherwise. + + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. + """ + fn_cache = lru_cache(maxsize)(fn) + prior_version = 0 + + if config.validate_shape_env_version_key: + prior_key = None + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), ( + "ShapeEnv cache key changed without version being updated!" + ) + + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: # type: ignore[misc] + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) + + wrapper.cache_clear = fn_cache.cache_clear # type: ignore[attr-defined] + wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +@dataclass(frozen=True) +class RuntimeAssert: + """ + This is pretty similar to ShapeGuard but it also comes with a message, + and is exclusively used for things that MUST be true (unlike guards, + which can evaluate False, in which case you just choose not to use + a particular specialization) + """ + + expr: SympyBoolean + msg: str = field(repr=False) + stack: CapturedTraceback = field(repr=False) + + +# Used for printing SymExprs in compile_fx +class SymExprPrinter(PythonPrinter): + def _print_Float(self, expr: sympy.Float) -> str: + return str(float(expr)) + + +class _ShapeGuardPrinter(abc.ABC): + """ + Abstract base class for printers that convert symbolic expressions to string representations. + + This class provides common functionality for printing symbolic expressions with + special handling for symbols that represent tensor shapes, strides, etc. + Subclasses implement specific formatting for different output languages. + + Args: + symbol_to_source: Mapping from sympy symbols to their source objects + source_ref: Function to convert a source to its string representation + var_to_sources: Mapping from sympy symbols to their source objects (for error reporting) + """ + + def __init__( + self, + symbol_to_source: Mapping[sympy.Symbol, list[Source]], + source_ref: Callable[[Source], str], + var_to_sources: Mapping[sympy.Symbol, list[Source]], + ) -> None: + self.symbol_to_source = symbol_to_source + self.source_ref = source_ref + self.var_to_sources = var_to_sources + super().__init__() + + def _print_Float(self, expr: sympy.Float) -> str: + """Convert a sympy Float to a Python float string representation.""" + return str(float(expr)) + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + """ + Convert a sympy Symbol to its source representation. + + This method looks up the symbol in symbol_to_source mapping and returns + the string representation of its first source. If the symbol is not in + symbol_to_source (which can happen when symbols appear in guard expressions + through simplification or substitution), it falls back to var_to_sources. + + Args: + expr: The sympy Symbol to convert + + Returns: + String representation of the symbol's source + + Raises: + AssertionError: If the symbol is not found in either mapping + """ + assert isinstance(expr, sympy.Symbol), str(type(expr)) + + # Try symbol_to_source first, fall back to var_to_sources if not found + if source := self.symbol_to_source.get(expr): + return self.print_source(source[0]) + elif source := self.var_to_sources.get(expr): + return self.print_source(source[0]) + else: + + def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str: + return repr( + { + symbol: [s.name for s in sources] + for symbol, sources in src.items() + } + ) + + raise RuntimeError( + f"{expr} not in {repr_sources(self.symbol_to_source)} or " + f"{repr_sources(self.var_to_sources)}. This could be due to " + "the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) + + @abc.abstractmethod + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + ... + + @abc.abstractmethod + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its string representation. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression + """ + ... + + +class ShapeGuardPythonPrinter(_ShapeGuardPrinter, PythonPrinter): + """ + Python printer for shape guards that extends the base ShapeGuardPrinter. + + This class provides functionality to print symbolic expressions as Python code, + with caching to improve performance when printing the same expressions multiple times. + It handles printing of sources and expressions according to Python syntax. + + Args: + *args: Arguments passed to the parent classes. + """ + + def __init__(self, *args: Any) -> None: + super().__init__(*args) + self._print_cache: dict[sympy.Expr, str] = {} + + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation using the source_ref function. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + return self.source_ref(source) + + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its Python string representation with caching. + + This method first checks if the expression is already in the cache. + If found, it returns the cached result; otherwise, it delegates to + PythonPrinter's doprint method and caches the result. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression in Python syntax + """ + val = self._print_cache.get(expr, None) + if val is not None: + return val + else: + res = PythonPrinter.doprint(self, expr) + self._print_cache[expr] = res + return res + + +@deprecated( + "`torch.fx.experimental.symbolic_shapes.ShapeGuardPrinter` is deprecated, " + "please use `torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter` instead.", + category=FutureWarning, +) +class ShapeGuardPrinter(ShapeGuardPythonPrinter): + pass + + +class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter): + def __init__(self, *args: Any) -> None: + self.all_symbols: set[str] = set() + self.source_to_symbol: dict[Source, sympy.Symbol] = {} + super().__init__(*args) + + def print_source(self, source: Source) -> str: + if source in self.source_to_symbol: + return self.source_to_symbol[source].name + + source_name = source.name + mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name) + old_mangled_name = mangled_name + count = 0 + while mangled_name in self.all_symbols: + mangled_name = f"{old_mangled_name}_{count}" + count += 1 + self.source_to_symbol[source] = sympy.Symbol(mangled_name) + self.all_symbols.add(mangled_name) + return mangled_name + + def doprint(self, expr: sympy.Expr) -> str: + return CppPrinter.doprint(self, expr) + + +# A dataclass for storing shape guards +@dataclass(frozen=True) +class _ShapeGuardsHelper: + exprs: list[str] + + +# A dataclass for storing C++ expressions and helper variables +@dataclass(frozen=True) +class _CppShapeGuardsHelper(_ShapeGuardsHelper): + source_to_symbol: dict[Source, sympy.Symbol] + + +class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): + def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): + super().__init__(var_to_sources, lambda n: n.name, var_to_sources) + + +class DynamicDimConstraintPrinter(PythonPrinter): + """ + Printer for dynamic dim constraints. + - Instead of symbol s_k it prints its source t.size()[i] + - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. + + We use this to suggest code for specifying dynamic dim constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + source_name_to_debug_name: Mapping[str, str], + ): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_name_to_debug_name = source_name_to_debug_name + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) + return self.symbol_to_source[expr][0].name + + +class DimConstraints: + """ + Custom solver for a system of constraints on symbolic dimensions. + Solutions are "static" values or simplified "dynamic" constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + var_to_val: Mapping[sympy.Symbol, sympy.Integer], + marked_dynamic: set[sympy.Symbol], + source_name_to_debug_name: Mapping[str, str], + ) -> None: + # We try to solve systems of inequalities with 1 free variable. + self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = ( + defaultdict(set) + ) + # Among them, we prioritize solving for a free variable that has equalities. + # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() + # and removing a symbol from the former => removing it from the latter. + self._symbols_with_equalities: set[sympy.Symbol] = set() + # A solution of a free variable with equalities becomes a substitution. + # We use these substitutions to simplify other constraints. + # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. + self._substitutions: dict[sympy.Symbol, sympy.Integer] = {} + + # In general, constraints may have // and % operations. + # Of course, // can be expressed in terms of / and %. + # Our inequality solver can handle / but not %. So we need to transform them away. + # We do so by using the values of variables as hints to evaluate %. + # For soundness we record additional congruence guards and solve them separately. + self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val + self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set) + + # We do not try to (directly) solve inequalities with > 1 free variables. + # NOTE: free variables in these inequalities cannot also be in _substitutions. + self._multivariate_inequalities: set[SympyBoolean] = set() + + # We park external equalities between free variables here. + self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = [] + + # Solutions come in two forms: + # - (static) specializations + # - (dynamic) inequalities / congruences + self._static_results: set[str] = set() + self._dynamic_results: set[str] = set() + + # printer for solutions + self._dcp = DynamicDimConstraintPrinter( + symbol_to_source, source_name_to_debug_name + ) + + # inconsistencies found on substituting with concrete values / static solutions + self._inconsistencies: list[str] = [] + + # symbols that are marked dynamic + self._marked_dynamic = marked_dynamic + + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + + def rewrite_with_congruences(self, s: sympy.Symbol, expr: _SympyT) -> _SympyT: + """ + Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. + This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. + We solve the added congruences separately (using our congruence solver, see below). + """ + + def mod_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b % d with free variable s. + # Using the value of s as a "hint," we can evaluate b % d to a value k. + # Then we can rewrite b % d to k while adding the guard b % d == k. + + # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF + # the original expression always evaluates to a constant value (i.e., it does not vary with s). + # In other words, + # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with + # the original expression; + # - while it may be possible to find solutions of s with the original expression that are not + # solutions with the rewritten expression, in that case the original expression cannot evaluate + # to the same value for all solutions of s. + # + # Should we be worried about this incompleteness? No, because of the following reasons: + # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech + # (i.e., "don't let perfect be the enemy of the good"). + # 2. We already have a tradition of using hints to add guards in the compiler for making progress. + # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards + # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. + # + # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. + # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we + # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return mod_reduced + + def floor_div_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b // d with free variable s. + # Using the value of s, we can evaluate b % d to a value k. + # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. + + # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d + # and eliminating b % d as above. + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha + return (base - mod_reduced) / divisor + + # pyrefly: ignore [missing-attribute] + if expr.has(Mod): + # pyrefly: ignore [missing-attribute] + expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + # pyrefly: ignore [missing-attribute] + if expr.has(PythonMod): + # pyrefly: ignore [missing-attribute] + expr = expr.replace(PythonMod, mod_handler) + # pyrefly: ignore [missing-attribute] + if expr.has(FloorDiv): + # pyrefly: ignore [missing-attribute] + expr = expr.replace(FloorDiv, floor_div_handler) + return expr + + def _enumerate_sympy_functions(self) -> None: + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference( + self._supported_sympy_functions + ) + + def _has_unsupported_sympy_function(self, expr: sympy.Basic) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + + def add(self, expr: SympyBoolean) -> bool: + """Add an expression to the set of constraints. + + Return whether the expression is a trivial constraint (i.e., an obvious tautology). + """ + if expr == sympy.true: + return True + orig_expr = expr + orig_reduced = orig_expr.xreplace(self._var_to_val) + # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 + # It is possible that `expr` will fail the consistency check because of + # precision errors. Specifically, on substituting its free symbols with + # their concrete values, we might end up comparing floats. Until we have + # a fix for this issue, we delay raising such failures. See solve(). + if orig_reduced == sympy.false: + self._inconsistencies.append(f"{orig_expr} is inconsistent!") + if isinstance( + expr, (sympy.Ne, sympy.Or, sympy.And) + ) or self._has_unsupported_sympy_function(expr): + # we're not going to do anything useful with these, so drop them + return False + free_symbols = expr.free_symbols + assert free_symbols, f"Did not expect constraint with no free variables: {expr}" + if len(free_symbols) > 1: + # multivariate: record and move on + self._multivariate_inequalities.add(expr) + else: + # univariate: can solve these immediately + s = next(iter(free_symbols)) + # eliminate // and % (see documentation of `rewrite_with_congruences` above) + old_n_congruences = len(self._congruences[s]) + expr = self.rewrite_with_congruences(s, expr) + new_n_congruences = len(self._congruences[s]) + if expr == sympy.true: + return old_n_congruences == new_n_congruences + reduced = expr.xreplace(self._var_to_val) + if reduced == sympy.false: + self._inconsistencies.append( + f"{expr}, obtained by rewriting {orig_expr} with congruences, " + "is inconsistent!" + ) + if isinstance(expr, sympy.Eq): + # special status for symbols that have equalities (see `solve` below) + self._symbols_with_equalities.add(s) + self._univariate_inequalities[s].add(expr) + return False + + def add_equality(self, source: Source, expr: sympy.Expr) -> None: + """Add an equality constraint""" + if expr.is_number: + # specialization, right here + self._static_results.add(f"{source.name} == {expr}") + else: + # these will resolve to either specializations or dynamic equality constraints + self._symbolic_equivalences.append((source, expr)) + + def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]: + reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {} + for s, congruences in self._congruences.items(): + remainder_modulus_pairs = [] + congruences_to_check = set() + for congruence in congruences: + base, divisor = congruence.args + # We are given a congruence of the form base % divisor == 0 with a free variable s. So: + # - we transform this into an equation of the form base = divisor * tmp; + # - we solve this equation for s to get a linear solution with free variable tmp. + tmp = sympy.Symbol("reduce_congruences_tmp", integer=True) + symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) + # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear + # for how to interpret the results. + if s == symbol: + # This means the solution is of the form s = modulus*tmp + remainder. + modulus, remainder = sympy.polys.polytools.div(solution, tmp) + if isinstance(modulus, sympy.Integer) and isinstance( + remainder, sympy.Integer + ): + # Make sure 0 <= remainder <= modulus. + remainder = remainder % modulus + remainder_modulus_pairs.append((remainder, modulus)) + continue + # This means that we did not get a unique solution to the equation. + # No problem, we will check it. + congruences_to_check.add(congruence) + # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). + # The solution will be a congruence of the form s = r mod m. + # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. + if remainder_modulus_pairs: + remainder, modulus = sympy.ntheory.modular.solve_congruence( + *remainder_modulus_pairs + ) + reduced_congruences[s] = {(s - remainder) % modulus} + substitution = { + s: modulus * sympy.Symbol("tmp", integer=True) + remainder + } + reduced_congruences[s].update( + congruence + for congruence in congruences_to_check + if not sympy.checksol(congruence, substitution) + ) + else: + reduced_congruences[s] = congruences_to_check + + return reduced_congruences + + def _raise_inconsistencies(self) -> None: + if self._inconsistencies: + msg = "\n".join(self._inconsistencies) + self._inconsistencies.clear() + raise ValueError(f"The following inconsistencies were found:\n{msg}") + + def solve(self) -> None: + """Solve the system of constraint equations to find simplified constraints""" + self._raise_inconsistencies() + # as long as there are symbols with equalities, solve for them + # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) + while self._symbols_with_equalities: + s = self._symbols_with_equalities.pop() + exprs = self._univariate_inequalities.pop(s) + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + if isinstance(solution, sympy.And): + solution = next( + (arg for arg in solution.args if isinstance(arg, sympy.Eq)), + solution, + ) + assert isinstance(solution, sympy.Eq), ( + f"Expected an equality constraint for {s}, got {solution}" + ) + symbol, val = solution.args + assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" + # because this is univariate, the solution is a specialization + self._static_results.add( + f"{self._dcp.symbol_to_source[s][0].name} == {val}" + ) + # add this as a substitution to simplify other constraints + self._substitutions[s] = val # type: ignore[assignment] + + # simplify multivariate inequalities: some of them will now become univariate! + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.xreplace({s: self._substitutions[s]})) + self._raise_inconsistencies() + + # solve linear congruences + # NOTE(avik): We do not need to solve them for symbols that have already been specialized. + reduced_congruences = self._reduce_congruences() + for s, congruences in reduced_congruences.items(): + for congruence in congruences: + # any congruence that cannot be checked becomes a dynamic constraint as well + if s not in self._substitutions or not sympy.checksol( + congruence, {s: self._substitutions[s]} + ): + if self._is_supported_congruence(congruence): + base, divisor = congruence.args + tmp_name = "_" + str( + self._dcp.source_name_to_debug_name.get( + self._dcp.symbol_to_source[s][0].name, + self._dcp.symbol_to_source[s][0].name, + ) + ) + tmp = sympy.Symbol(tmp_name, integer=True) + from torch._dynamo.source import ConstantSource + + self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] + r = try_solve(sympy.Eq(base, divisor * tmp), s) + assert r is not None + self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) + + # remaining symbols have only pure inequalities (no equalities) + for s, exprs in self._univariate_inequalities.items(): + try: + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + # because this is univariate, the solution is a dynamic (range) constraint + if isinstance(solution, sympy.Or): + solution = next( + iter( + arg + for arg in solution.args + if arg.xreplace(self._var_to_val) + ) + ) + if isinstance(solution, sympy.And): + for arg in solution.args: + self._dynamic_results.add(self._dcp.doprint(arg)) + else: + self._dynamic_results.add(self._dcp.doprint(solution)) + except (NotImplementedError, AssertionError): + log.warning("Failed to reduce inequalities", exc_info=True) + for expr2 in exprs: + self._dynamic_results.add(self._dcp.doprint(expr2)) + + # simplify symbolic equivalences: some of them will now become specializations! + symbolic_equivalences = self._symbolic_equivalences + self._symbolic_equivalences = [] + for source, expr3 in symbolic_equivalences: + self.add_equality(source, expr3.xreplace(self._substitutions)) + + # remaining symbolic equivalences become dynamic equality constraints + for source, expr3 in self._symbolic_equivalences: + self._dynamic_results.add(f"{source.name} == {self._dcp.doprint(expr3)}") + + @classmethod + def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool: + base, divisor = congruence.args + # Congruences that can be currently expressed with supported Dim ops are + # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. + # This allows us to derive x as b*y - a for some Dim y. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(base, sympy.Add): + lhs, rhs = base.args + cond = ( + isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer) + ) or (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) + else: + cond = isinstance(base, sympy.Symbol) + cond = cond and isinstance(divisor, sympy.Integer) + return cond + + def forced_specializations(self) -> dict[str, sympy.Expr]: + """Returns a dictionary of the names of symbols to their specialized value""" + + def debug_name(src: Source) -> str: + name = src.name + if self._dcp.source_name_to_debug_name: + return f"{self._dcp.source_name_to_debug_name[name]} = {name}" + else: + return name + + return { + debug_name(self._dcp.symbol_to_source[s][0]): val + for s, val in self._substitutions.items() + if s in self._marked_dynamic + } + + def _is_derived_dim( + self, dim: object + ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]: + return isinstance(dim, torch.export.dynamic_shapes._DerivedDim) + + def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]: + return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance( + dim, torch.export.dynamic_shapes._DerivedDim + ) + + def _process_derived_dim_roots( + self, + results: dict[str, dict[str, Any]], + name_to_dim: dict[str, Any], + ) -> None: + """ + Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, + and 2) root swapping. + + 1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests + dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final + suggested fixes handle this correctly, but we can get intermediate results that look like + {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}} + and this routine prettifies this by unifying to a single root, and making each suggestion + either a derived dim or min/max range, not both. + + 2) With suggested fixes for derived dims, roots can be swapped, + e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name, + since this leads to messages like "dx - 1 = Dim("dx - 1", ...)". + Instead we evaluate the new root value, and remove results for its derivations. + + First we find all the original roots (specified in dynamic_shapes), that are found in the + values of results (i.e. used for computing suggesting fix values). These original roots + (suppose `dx`) are either specialized, unchanged, refined, or swapped + (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value + in results, and remove suggestions for derivations of `dx`, assuming the derived relation + is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value, + and then do the same with `dx`'s derivations. + + Assuming the originally specified derived relations are correct is valid, because: + 1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1)) + produce_guards() will catch this and crash before hand. + 2) if the relations are numerically correct but do not match the emitted guard, + for example: + + def forward(self, x, y): + return x.reshape([-1]) + y # guard: s0 * 2 = s1 + inputs = (torch.randn(6, 2), torch.randn(12)) + dx = Dim("dx", min=2, max=32) + dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op + + then this leads to 2 linear equations, and a) produce_guards() is able to solve for + the unique solution of dx = 6 and specialize, and b) the export constraint solver will + raise an issue due to range constraints (a unique solution means not all values in a + range satisfy a guard) and also force specializations. + """ + from torch.export.dynamic_shapes import Dim + + def _check_same_range(c: Mapping[str, int], dim: object) -> bool: + # returns True if c & dim are both min/max ranges with same values + return ( + self._is_dim(dim) + and ("min" in c or "max" in c) + and ( + (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2) # type: ignore[attr-defined] + ) # let pass if analysis min = 2 and specified min = 0/1 + and dim.max == c.get("max", int_oo) # type: ignore[attr-defined] + ) + + # 1) newly introduced roots + # this part we handle adding newly introduced roots + # these arise from guards like "x.shape[0] % 3 == 0" + # leading to suggested fixes like "dx = 3*_dx" + # extract _dx, and find appropriate min/max values + # + # before, we have something like: + # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} + # we want instead: + # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} + introduced_roots: dict[str, str] = {} # map new root -> old root + for k, c in list(results.items()): + if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim + root = next(iter(c["eq"].free_symbols)) + if str(root) not in name_to_dim: + introduced_roots[str(root)] = k + # calculate necessary min & max + modulus, remainder = sympy.polys.polytools.div(c["eq"], root) + c_min = c.get("min", 2) + min_ = math.ceil((c_min - remainder) / modulus) + c_max = c.get("max", int_oo) + max_ = math.floor((c_max - remainder) / modulus) + # create result & dim + results[str(root)] = {"min": min_, "max": max_} + name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_) + # remove old root min/max bounds + c.pop("min", None) + c.pop("max", None) + + # alter derivations that depend on old root, to unify to new root + # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 + for old_root in introduced_roots.values(): + for c in results.values(): + if ( + "eq" in c + and isinstance(c["eq"], sympy.Expr) + and str(symbol := next(iter(c["eq"].free_symbols))) == old_root + ): # derived dim with root = old_root + new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 + + new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 + c["eq"] = new_expr + + # 2) root swapping + # collect all the original roots that are used for calculating values of suggested fixes + # this consists of: + # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim + # 2) {"dy": "dx + 1"} -> dx: root for suggested fix + modified_roots: set[str] = set() + for k, c in results.items(): + if k not in name_to_dim: # _dynamo.export() may handle source directly + continue + if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1) + modified_roots.add(k) + elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2) + root = next(iter(c["eq"].free_symbols)) + assert root is not None + modified_roots.add(str(root)) + + # exclude newly introduced roots, we've already processed these + modified_roots = modified_roots.difference(introduced_roots) + + # evaluate the new value for each root + # this is now either 1) unchanged, 2) refined with a new range, + # or 3) specialized to a concrete value + modified_root_values: dict[str, dict[str, Any]] = {} + for mroot in modified_roots: + swapped_root = True + if mroot in results: + c = results[mroot] + if ("min" in c or "max" in c) or isinstance( # range + c["eq"], int + ): # specialized + # here, the original root is a root Dim or concrete value in results. + # if it is a derived dim, it is swapped, and we handle that below. + if not _check_same_range( + c, name_to_dim[mroot] + ): # ignore if unchanged + modified_root_values[mroot] = c + swapped_root = False + + if swapped_root: + # if the original root has been swapped in results, that means the new root + # is a range (if it had specialized, the original root would have too). + # find this new root, and solve for the original root's range. + for k, c in results.items(): + if k not in name_to_dim: + continue + dim = name_to_dim[k] + if ( + dim.__class__.__name__ == "_DerivedDim" + and dim.root.__name__ == mroot + ): + # only look for min/max root, otherwise root would have specialized + if "min" in c or "max" in c: + expr = sympy.sympify(k) + s = next(iter(expr.free_symbols)) + result = { + "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type, index] + "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] + } + if not _check_same_range( + result, + name_to_dim[mroot], # type: ignore[index, arg-type] + ): # ignore if unchanged + modified_root_values[mroot] = result # type: ignore[index] + break + + # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) + # we only want to suggest fixes for the root, to avoid derived names. + # also, remove anything in modified_roots, since we either add new modified values after this, + # or have decided they are unchanged. + for k in list(results.keys()): + if k not in name_to_dim: + continue + if self._is_derived_dim(name_to_dim[k]) or k in modified_roots: + del results[k] + + # update results with modified root values + # now results has the following properties: + # - only contains original roots as keys + # - each root is now either specialized, refined, or derived from another original root + results.update(modified_root_values) + + def prettify_results( + self, + original_signature: inspect.Signature, + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], + constraint_violation_error: object, + forced_specializations: dict[str, str], + ) -> str: + """Format a message for constraint violation errors""" + from torch.export.dynamic_shapes import _get_dim_name_mapping + + if not self._dcp.source_name_to_debug_name: + # nothing to do + return "" + + def transform(s: str, inverse: bool = False) -> str: + for k, v in self._dcp.source_name_to_debug_name.items(): + s = s.replace(k, v) if not inverse else s.replace(v, k) + return s + + results: defaultdict[str, dict[str, Any]] = defaultdict(dict) + if dynamic_shapes is None: + dynamic_shapes = {} + + def flip(op: str) -> str: + if op == "<=": + return ">=" + if op == ">=": + return "<=" + if op == "<": + return ">" + if op == ">": + return "<" + assert op == "==" + return op + + def relation_with_digit(expr: str, op: str, digit: int) -> None: + if op == "<=": + results[expr]["max"] = digit + elif op == "<": + results[expr]["max"] = digit - 1 + elif op == ">=": + results[expr]["min"] = digit + elif op == ">": + results[expr]["min"] = digit + 1 + else: + assert op == "==" + results[expr]["eq"] = digit + + # retrieve dynamic shapes + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + for s in self._static_results.union(self._dynamic_results): + t = transform(s) + if t == s: + continue + left, op, right = re.split(r"( == | <= | >= | < | > )", t) + op = op.strip() + if op == "==" and left == right: + continue + if right.isdigit(): + relation_with_digit(left, op, int(right)) + elif left.isdigit(): + relation_with_digit(right, flip(op), int(left)) + else: + assert op == "==", t + try: + results[left]["eq"] = sympy.sympify(right) + except TypeError: # rhs source is not linked to Dim name + pass + + # order forced specializations based on name + forced_specializations = { + k: forced_specializations[k] + for k in sorted( + forced_specializations.keys(), + key=lambda x: x.split(" = ")[1], + ) + } + + buf = "" + if forced_specializations: + debug_names = set() + for k in forced_specializations: + dim = name_to_dim[k.split(" = ")[0]] + if self._is_derived_dim(dim): + debug_names.add(dim.root.__name__) # type: ignore[attr-defined] + else: + debug_names.add(dim.__name__) + + buf += ( + f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + ) + for s, val in forced_specializations.items(): + buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n" + + self._process_derived_dim_roots(results, name_to_dim) + + dims = [] + others = [] + + # order results by source name + results2 = { + k: results[k] + for k in sorted( + results.keys(), + key=lambda x: transform(x, inverse=True), + ) + } + for k, c in results2.items(): + if "eq" in c: + other = c["eq"] + if isinstance(other, int): + others.append(f"{k} = {other}") + elif _is_supported_equivalence(other): + others.append(f"{k} = {other}") + else: + min_ = c.get("min", None) + if min_ == 2: + min_ = None + max_ = c.get("max", None) + if min_ is not None and max_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") + elif min_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_})") + elif max_ is not None: + dims.append(f"{k} = Dim('{k}', max={max_})") + else: + dims.append(f"{k} = Dim('{k}')") + + # results2 will get filtered out if no new suggestions, + # this can happen if guards are too complex. + # in that case don't suggest fix + if dims or others: + buf += "\nSuggested fixes:\n " + buf += "\n ".join(dims + others) + + return buf + + +TLS = threading.local() + + +@dataclass(frozen=True) +class ShapeEnvSettings: + """ + Encapsulates all shape env settings that could potentially affect + FakeTensor dispatch. Used when creating dispatch cache keys. + """ + + allow_scalar_outputs: bool + allow_dynamic_output_shape_ops: bool + assume_static_by_default: bool + specialize_zero_one: bool + duck_shape: bool + prefer_deferred_runtime_asserts_over_guards: bool + trace_asserts: bool + + +@dataclass +class ValueRangesSLoc: + """ + Locations of the guards that triggered lower and upper bound. + """ + + lower: SLoc + upper: SLoc + + +@contextmanager +def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]: + shape_env._suppress_guards_enter() + try: + yield + finally: + shape_env._suppress_guards_exit() + + +@dataclass +class _FrameLocalResult: + loc: Optional[str] = None + locals: dict[str, Any] = field(default_factory=dict) + symbols: dict[str, str] = field(default_factory=dict) + + +class ShapeEnv: + # This is a wrapper over the actual __init__ function. + # + # Where to add a new constructor parameter to ShapeEnv? + # ===================================================== + # This __init__ function should be used only for parameters related to event recording. + # These are parameters that we don't wish to pass down the road to new ShapeEnv instances + # created from replaying events. + # + # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event + # recording, do so in the _init function. + def __init__( + self, + *, + should_record_events: Optional[bool] = None, + tracked_fakes: Optional[list[Any]] = None, + **kwargs: Any, + ) -> None: + self._init(**kwargs) + + # Disable event recording when replaying. + kwargs["should_record_events"] = False + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + # If not specified, enable event recording if both: + # - Translation validation is on + # - Translation validation bisection is not disabled + self.should_record_events = ( + should_record_events + if should_record_events is not None + else ( + self._translation_validation_enabled + and not config.translation_validation_no_bisect + ) + ) + + # Enable event recording check if both: + # - It should record events + # - The recording check is enabled + self.check_recorded_events = ( + self.should_record_events and config.check_shape_env_recorded_events + ) + + # This will make sure we only record the top-level function call. + self.is_recording = False + # Keep track of the list of tracked fakes. + self.tracked_fakes = tracked_fakes + # List of events for reconstructing ShapeEnv at arbitrary points in time. + self.events: list[ShapeEnvEvent] = ( + [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] + if self.should_record_events + else [] + ) + + # FakeTensor per-ShapeEnv operation cache. This is used for caching + # operations that contain symbolic shapes which have guards on the + # ShapeEnv (so are ShapeEnv-dependent). + # + # NOTE: It's important that SymNodes in this cache have their ShapeEnv + # stripped otherwise you end up with cycles which can only be cleaned + # with the GC. + self.fake_tensor_cache: dict[ + torch._subclasses.fake_tensor._DispatchCacheKey, + torch._subclasses.fake_tensor._DispatchCacheEntry, + ] = {} + + # Pro-tip: if you add new field to ShapeEnv, this affects some accept + # tests. Accept their output with: + # + # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal + # + def _init( + self, + *, + allow_scalar_outputs: bool = True, + allow_dynamic_output_shape_ops: bool = True, + # NB: These are legacy configuration that help us make good choices + # when the constraint/dynamic dims are not explicitly passed to us. + # Ideally we will fix all call sites to be explicit and not have + # implicit choices, but this apparently was pretty involved. + assume_static_by_default: bool = False, + # Note - On 0/1 specialization + # + # The following options affect decisions we make about eager + # specialization. Disabling them will increase trace time (as we do + # more symbolic reasoning) and can also harm the quality of generated + # code (because inductor may not be able to specialize for bounds + # being equal--although if we later respecialize because of a guard, + # your code may be just as good as it was before.) + # + # When True, eagerly specialize input sizes which have 0/1. + specialize_zero_one: bool = True, + # When True, assume input sizes which have the same size are + # symbolically equal. + duck_shape: Optional[bool] = None, + # For debugging + co_fields: Optional[dict[str, str]] = None, + # When True, whenever safe, we will generate a deferred runtime assert + # instead of a guard whenever we know that an expression must be True, + # otherwise it would be an error, even for backed SymInts (where we + # could ostensibly unconditionally generate guards). This is useful + # for export, where preventing "error checking" sizes from showing up + # in guards is helpful, since these guards in some sense are overly + # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 + prefer_deferred_runtime_asserts_over_guards: bool = False, + # XXX Add any new settings that could affect FakeTensor evaluation + # to: torch._subclasses.fake_tensor._ShapeEnvSettings + trace_asserts: bool = False, + ) -> None: + if duck_shape is None: + duck_shape = config.use_duck_shape + + self.settings = ShapeEnvSettings( + # Not directly used by ShapeEnv; indirectly used by FakeTensor + allow_scalar_outputs=allow_scalar_outputs, + allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops, + # End + assume_static_by_default=assume_static_by_default, + specialize_zero_one=specialize_zero_one, + duck_shape=duck_shape, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + trace_asserts=trace_asserts, + ) + + self.guards: list[ShapeGuard] = [] + self.axioms: dict[sympy.Expr, sympy.Expr] = {} + + # A set of ids that have already been allocated. This is used + # for when we allocate symbol ids using the hash of the source + # names to ensure we don't have collisions via linear probing + self.unique_ids: set[int] = set() + # Maps symbolic ints to their original concrete values + # Currently populated from tensors + self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like var_to_val, but only set when propagate_real_tensors is on. + # Used as last resort to avoid GuardOnDataDependent error + self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like above, but used exclusively for OBLIVIOUS_SIZE. These + # potentially could be put together but I am not sure, writing out + # the logic individually before abstracting. + self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Maps symbolic ints to their min/max range. These ranges + # are conservative: the int MUST fall in the range, but the + # range may contain ints which may not actually appear in + # practice + self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} + self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} + self.source_name_to_debug_name: dict[str, str] = {} + self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} + # A set of unbacked symbols that are inputs (i.e: not data dependent). + self.unbacked_inputs: OrderedSet[sympy.Symbol] = OrderedSet() + self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} + self.var_to_hint_override: dict[sympy.Symbol, int] = {} + # Maps a source to the *original* symbol that was assigned to it + self.source_to_var: dict[str, sympy.Symbol] = {} + # Maps from sympy ints to expressions representing them + # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) + self.replacements: dict[sympy.Symbol, sympy.Expr] = {} + # The sloc of the guard that triggered this replacement to be added + self.replacements_slocs: dict[sympy.Symbol, SLoc] = {} + self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} + # Set holds a % b expressions that evaluate to 0. + self.divisible: set[sympy.Expr] = set() + # Set that holds "size-like" symbols. When we perform + # "size-oblivious" tests, these can be assumed to be >= 2. + self.size_like: set[sympy.Symbol] = set() + # Duck-shaping says that if two input tensors have the same size, + # they get assigned the same symbolic variable + self.val_to_var: dict[int, sympy.Symbol] = {} + self.unbacked_symfloat_counter = 0 + self.unbacked_symint_counter = 0 + # Similar to guards, but these MUST evaluate to true and can + # only be evaluated at runtime midway through (i.e., they always + # involve unbacked symints) + # + # For efficiency reasons, we index in the following way. Suppose you have + # a runtime assert i0 + i1 <= s1. We pick the most recently allocated + # symbol in the source expression and add the assert to the list for + # that symbol e.g., {i1: [i0 + i1 <= s1]}. + # + # We access the runtime asserts in two situations: + # + # - When we are guarding on an expression, we will attempt to + # statically evaluate it, in case the unbacked SymInts can + # simplify away. If we have a runtime assert, we may be able + # to discharge the guard entirely. We only need to attempt + # runtime asserts that mention freevars of the expression in + # question. + # + # - When we are performing codegen (in Inductor for eager, or + # when finalizing the export FX graph), we need to know what + # extra runtime asserts to insert. Whenever an unbacked + # SymInt comes into scope, all runtime asserts involving it + # become eligible for insertion (so long as all of their other + # free unbacked symbols are also in scope). We technically + # can handle any choice of key by kicking inexpressible asserts + # to the next unbacked symbol to wait on, but if we choose the + # latest key, an assert will only show up at the moment when + # we can actually codegen it. + self.deferred_runtime_asserts: dict[ + Optional[sympy.Symbol], list[RuntimeAssert] + ] = {} + # This exists so we can efficiently invalidate the cache (it's used as + # part of the cache key); otherwise we'd have to iterate through + # deferred_runtime_asserts to compute its length + self.num_deferred_runtime_asserts = 0 + self.log = log + self.log.info("create_env") + self.frozen = False + self.runtime_asserts_frozen = False + self.dim_constraints: Optional[DimConstraints] = None + self.counter: Counter[str] = collections.Counter() + # Mapping from sympy.Symbol to the number of guards which mention this + # symbol + self.symbol_guard_counter: Counter[sympy.Symbol] = collections.Counter() + # A selection of important fields on co_field; solely used for + # signpost_event + self.co_fields = co_fields if co_fields else {} + + # Whenever we allocate a fresh unbacked Symbol, we add it to this + # pending list. Unbacked symbol allocation can occur at unpredictable + # points during meta tensor propagation, but at some point, we + # have to know what the binding site for an unbacked symbol is, and + # this is computed when we actually place the node in the graph. The + # important thing is that we always actually handle every unaccounted + # for unbacked symbol, so this list helps us keep track of them and + # then make sure they are all accounted for. + # + # We could potentially give rise to errors earlier by lexically + # scoping when we do propagation, and only allowing unbacked symbols + # to be allocated at this point in time. However this is inconvenient + # to do in Dynamo, because fake tensor propagation is far from when we + # analyze binding sites (set_example_value), so we do it in a more + # mutatey way. + # + # NB: fresh unbacked symbols NEVER get substitutions applied to them, + # they are binding sites! + self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = [] + + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + + # Each time divisible is changed this should be set to True, this is set in _update_version_counter. + self._resimplify_floor_div_axioms = True + + # Cache for FX nodes. + # Maps an already built node a tuple of: + # 1. node's target + # 2. list of arguments + # This drastically reduces the size of the FX graph, avoiding + # duplicated nodes. + self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: dict[str, sympy.Symbol] = {} + + # Suppose you want to replace an unbacked symbol with another + # unbacked symbol. This is error prone because you can cause + # references to unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. + # + # To control for this, we track the order unbacked symbols + # were allocated, and only allow substitutions if they respect + # the dependency from this order; an unbacked symbol can only + # be substituted with unbacked symbols that come before it in the + # order. + # + # This also imposes an ordering on the unbacked symbol binding + # sites themselves: you are not allowed to reorder unbacked symbol + # bindings. At the moment, this is not tracked, but we potentially + # could track this at the IR level using a higher order operator + # with something like effect token tracking. + self.unbacked_alloc_order: dict[sympy.Symbol, int] = {} + + self.specialization_stacks: dict[Source, traceback.StackSummary] = {} + + self.trace_asserts = trace_asserts + + self.specializations: OrderedSet[Specialization] = OrderedSet() + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import TranslationValidator + + self.validator = TranslationValidator() + self.graph = torch.fx.Graph() + # Create an output graph and start inserting before that. + # This is needed when 'deepcopy'-ing this object. + self.graph.inserting_before(self.graph.output(None)) + + # Mapping of each node name to the node itself. + # + # This is useful for matching an FX node from a recorded ShapeEnv.graph + # to the FX node of the ShapeEnv we are running the event on. + # + # Whenever you add a node to self.graph, you must add a mapping to this + # variable. Otherwise, the built FX graph on the replayed ShapeEnv will + # not be valid. + self.name_to_node: dict[str, torch.fx.Node] = {} + + @property + def allow_scalar_outputs(self) -> bool: + return self.settings.allow_scalar_outputs + + @property + def allow_dynamic_output_shape_ops(self) -> bool: + return self.settings.allow_dynamic_output_shape_ops + + @property + def assume_static_by_default(self) -> bool: + return self.settings.assume_static_by_default + + @property + def specialize_zero_one(self) -> bool: + return self.settings.specialize_zero_one + + @property + def duck_shape(self) -> bool: + return self.settings.duck_shape + + @property + def prefer_deferred_runtime_asserts_over_guards(self) -> bool: + return self.settings.prefer_deferred_runtime_asserts_over_guards + + @contextmanager + def patch_source_specialization( + self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] + ) -> Iterator[None]: + """ + Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork" + and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph + compile so we can support various graphs with varying levels of specializations. + + This context manager allows for temporarily adding constraints to the shape environment + based on a specialization function applied to a symbol associated with a source. + + Args: + source: The source of the symbol to specialize + check_fn: A function that takes a sympy Symbol and returns a sympy expression + representing a constraint/specialization to be applied + """ + name = source.name + sym = self.source_to_var[name] + expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr + new_axioms = dict(self.get_implications(self.simplify(expr))) + added_replacements = {} + + for axiom in new_axioms: + if ( + isinstance(axiom, sympy.Eq) + and isinstance(axiom.lhs, sympy.Symbol) + and isinstance(axiom.rhs, sympy.Integer) + and axiom.lhs not in self.replacements + ): + self.replacements[axiom.lhs] = axiom.rhs + added_replacements[axiom.lhs] = axiom.rhs + self.axioms.update(new_axioms) + + # We need to freeze the ShapeEnv because any additional modification of + # the ShapeEnv will cause unsoundness for subsequent specialization calls. + self.frozen = True + try: + yield + finally: + for k in new_axioms: + self.axioms.pop(k, None) + for k in added_replacements: + self.replacements.pop(k, None) + self.frozen = False + + def check_equal(self, other: ShapeEnv) -> None: + """Compare another ShapeEnv for equivalence""" + # ShapeEnv fields that are not relevant for the outcome of + # ShapeEnv.produce_guards call: + # - Debugging variables + # - Translation validation related variables + # - Events recording related variables + non_state_variable_names = ( + "counter", + "log", + "var_to_stack", + "fx_node_cache", + "graph", + "validator", + "check_recorded_events", + "should_record_events", + "is_recording", + "tracked_fakes", + "events", + "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", + "dim_constraints", + # source locations are OK to diverge + "var_to_range_sloc", + "replacements_slocs", + "_resimplify_floor_div_axioms", + "_expr_sym_node_id", + "specialization_stacks", + ) + + # Mapping of the value of each to-be-compared field into the values that + # should actually be compared. + # + # You should modify this if, for example, the field that holds state and + # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) + # and the stack when it was added to the set of guards. In order to compare + # it, we throw away the stack information. + def map_value(key: str, value: Any) -> Any: + if key == "guards": + # Transform the list of ShapeGuard into a list of expressions. + return [g.expr for g in value] + elif key == "deferred_runtime_asserts": + # Transform the list of RuntimeAsserts into a list of expressions. + return {s: [ra.expr for ra in ras] for s, ras in value.items()} + elif key == "name_to_node": + # Compare just the set of keys is the same. + return set(value.keys()) + elif key in ( + "symbol_guard_counter", + "pending_fresh_unbacked_symbols", + "fake_tensor_cache", + ): + # Skip this for comparisons + return None + return value + + shape_env_check_state_equal(self, other, non_state_variable_names, map_value) + + def _snapshot_tracked_fakes(self) -> Optional[list[Any]]: + if self.tracked_fakes is None: + return None + + from torch._dynamo.variables.builder import TrackedFake + + def maybe_transform_fake(fake: TrackedFake) -> TrackedFake: + inner_fake = ( + fake.fake + if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) + else FakeTensorMeta.from_fake(fake.fake) + ) + # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a + # FakeTensorMeta for two reasons: + # 1. this is all the information we need when recording ShapeEnvEvents. + # 2. it works even if each TrackedFake changes its metadata. + return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] + + return [maybe_transform_fake(fake) for fake in self.tracked_fakes] + + def _last_event_index(self) -> int: + return len(self.events) - 1 + + @contextmanager + def _recording(self) -> Iterator[None]: + self.is_recording = True + try: + yield + finally: + self.is_recording = False + + @record_shapeenv_event() + def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr) -> None: + self._set_replacement(orig_s, new_s, "eliminate_unbacked") + + @record_shapeenv_event() + def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: + """Used only when propagate_real_tensors; registers a value for an + unbacked symbol, which can be used last resort to resolve hints.""" + log.info("set_unbacked_var_to_val %s = %s", k, v) + self.unbacked_var_to_val[k] = sympy.sympify(v) + + # Unlike set_replacement, this records a shapeenv event + @record_shapeenv_event() + def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol) -> None: + assert isinstance(orig_s, sympy.Symbol), orig_s + assert isinstance(new_s, sympy.Symbol), new_s + assert free_unbacked_symbols(new_s), new_s + assert free_unbacked_symbols(orig_s), orig_s + dest = self.replacements.get(orig_s) + if dest is not None: + assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" + self._set_replacement(orig_s, new_s, "rename_unbacked_to") + self.unbacked_renamings[orig_s] = new_s + if dest is not None: + self._set_replacement(new_s, dest, "rename_unbacked_to_dest") + + @record_shapeenv_event() + def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None: + # TODO: Do something nontrivial when upper_bound is expression + pass + + @record_shapeenv_event() + def _constrain_range_for_size( + self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None + ) -> None: + if min is None: + min = 0 + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + f"received min={min} and max={max}" + ) + + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + self.size_like.add(a) + + @record_shapeenv_event() + def _constrain_range(self, a: sympy.Expr, min: int, max: int) -> None: + if isinstance(a, sympy.Integer): + if not (min <= int(a) <= max): + raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") + return + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but it this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + if isinstance(a, sympy.Symbol): + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + + @record_shapeenv_event() + def _constrain_unify(self, a: SymInt, b: SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + # Update Feb 2024: this is extra important to do, this doesn't handle + # unbacked replacements properly nor does it generate deferred runtime + # asserts + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + assert b.node.shape_env is self + self.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert a.node.shape_env is self + if not isinstance(b, SymInt): + self.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + new_var = self._find(a.node.expr) + self.replacements[b.node.expr] = new_var + + def _ignore_fresh_unbacked_symbols_tls(self) -> bool: + return getattr(TLS, "ignore_fresh_unbacked_symbols", False) + + @record_shapeenv_event() + def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool: + prev = self._ignore_fresh_unbacked_symbols_tls() + TLS.ignore_fresh_unbacked_symbols = b + return prev + + @contextmanager + def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: + """ + Indicates that the newly allocated unbacked SymInts are being + discarded + """ + prev = self._ignore_fresh_unbacked_symbols_set(True) + try: + yield + finally: + self._ignore_fresh_unbacked_symbols_set(prev) + + @record_shapeenv_event() + def freeze(self) -> None: + """Freeze this ShapeEnv to stop accumulating guards + + A frozen ShapeEnv will ignore any further guards generated on it and + only emit a warning which may lead to accuracy problems. + """ + self.frozen = True + + @record_shapeenv_event() + def freeze_runtime_asserts(self) -> None: + """Freeze this ShapeEnv to stop adding deferred runtime asserts. + + We will error if you try to install a new runtime assert when it is + frozen. This would indicate a lowering violation, or perhaps something + we know statically is already True but we are checking it again in a way + that is not clearly dischargeable. + """ + # self.prefer_deferred_runtime_asserts_over_guards = False + self.runtime_asserts_frozen = True + + def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: + if not self._translation_validation_enabled: + return None + srcname = source.name + if source not in self.source_to_symbol: + self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) + return self.source_to_symbol[srcname] + + def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None: + if self._translation_validation_enabled: + self.validator.add_var(symbol, type) + + def _add_target_expr(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_target_expr(expr) + + def _add_assertion(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_assertion(expr) + + def _check_translation_validate(self) -> None: + if self._translation_validation_enabled: + self.validator.validate() + + @record_shapeenv_event() + def _create_fx_call_function( + self, + op: Callable, + args: tuple, + ) -> tuple[Optional[torch.fx.Node], bool]: + # Cache this tuple in order to avoid duplicated nodes. + node_key = (op, args) + # Flags whether the returned node was cached or not. + fresh = False + + if self._translation_validation_enabled and node_key not in self.fx_node_cache: + # Presence of None in the arguments implies that we should ignore this operation. + if any(a is None for a in args): + # We check if we are not mixing SymNode that should not be ignored + # (fx_node is not None) with those that should (fx_node is None). + assert all(not isinstance(a, torch.fx.Node) for a in args) + return None, fresh + + fresh = True + + # If translation validation is enabled, all arguments must have its + # own FX node. + assert all(a is not None for a in args), ( + f"missing arg in FX graph ({op.__name__}): {args}" + ) + node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) + self.name_to_node[node.name] = node + + return self.fx_node_cache.get(node_key, None), fresh + + def _create_fx_placeholder_and_z3var( + self, + symbol: sympy.Symbol, + type: type, + ) -> Optional[torch.fx.Node]: + if not self._translation_validation_enabled: + return None + + node_key = (self.graph.placeholder, (symbol,)) + + # Check if we haven't added this symbol already. + # If so, skip the placeholder creation, as it + # generates invalid Python code. + if node_key not in self.fx_node_cache: + # Add a Z3 variable according to 'type'. + self._add_z3var(symbol, type) + # Create the FX placeholder out of a mangled name. + mangled_name = re.sub( + r"[^a-zA-Z0-9]", "_", re.sub(r"[()]", "", symbol.name) + ) + node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) + self.name_to_node[node.name] = node + # Attach the 'symbol' to the placeholder so that we can retrieve + # the Z3 variable later. + node.meta["symbol"] = symbol + + return self.fx_node_cache[node_key] + + def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: + if self._translation_validation_enabled and node is not None: + self.name_to_node.pop(node.name) + self.graph.erase_node(node) + + def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: + from torch._dynamo.utils import get_current_node + + if self.should_record_events: + node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() + node.meta[CURRENT_NODE_KEY] = get_current_node() + + @staticmethod + def _suppress_guards_tls() -> bool: + return getattr(TLS, "suppress_guards", False) + + @record_shapeenv_event() + def _suppress_guards_enter(self) -> None: + if not hasattr(TLS, "suppress_guards_stack"): + TLS.suppress_guards_stack = [] + old = self._suppress_guards_tls() + TLS.suppress_guards_stack.append(old) + TLS.suppress_guards = True + + @record_shapeenv_event() + def _suppress_guards_exit(self) -> None: + old = ( + TLS.suppress_guards_stack.pop() + if len(TLS.suppress_guards_stack) > 0 + else False + ) + TLS.suppress_guards = old + + def suppress_guards(self) -> _GeneratorContextManager[None]: + """Context manager to ignore all guards generated inside""" + return _suppress_guards(self) + + def _get_key(self) -> tuple[int, int, int, int]: + """ + Defines the current "state" of the guards we've accumulated in this ShapeEnv. + Determines when we need to invalidate our cache + """ + return ( + len(self.replacements), + len(self.divisible), + self.num_deferred_runtime_asserts, + len(self.unbacked_var_to_val), + ) + + def _update_version_counter(self) -> None: + # if the change to shape env effects self.divisible set + # _resimplify_floor_div_axioms. + # This is used to trigger a resimplication of FloorDiv to CleanDivs + # in implication inside the function resimplify_floor_div. + if len(self.divisible) != self._prev_cache_key[1]: + self._resimplify_floor_div_axioms = True + + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + + def _produce_dyn_sizes( + self, + ex_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + return self._produce_dyn_sizes_from_int_tuple( + tuple(ex_size), source, symbolic_context + ) + + def _produce_dyn_sizes_from_int_tuple( + self, + tensor_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + hint_overrides: Optional[dict[int, int]] = None, + ) -> list[sympy.Expr]: + assert all(not is_symbolic(val) for val in tensor_size), ( + f"Expect size to be a plain tuple of ints but got {tensor_size}" + ) + from torch._dynamo.source import TensorProperty, TensorPropertySource + + if not hint_overrides: + hint_overrides = {} + + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + constraint_dims = symbolic_context.constraint_sizes # type: ignore[attr-defined] + size = [] + for i, val in enumerate(tensor_size): + sym = self.create_symbol( + hint_overrides.get(i, val), + TensorPropertySource(source, TensorProperty.SIZE, i), + dynamic_dims[i], + constraint_dims[i], + do_not_specialize_zero_one=config.backed_size_oblivious, + symbolic_context=symbolic_context, + ) + if ( + isinstance(symbolic_context, StatelessSymbolicContext) + and symbolic_context.specialize_on + ): + for specialization in symbolic_context.specialize_on[i]: + self.specializations.add( + Specialization( + TensorPropertySource(source, TensorProperty.SIZE, i), + specialization, + ) + ) + if ( + config.backed_size_oblivious + and isinstance(sym, sympy.Symbol) # could be static + and symbol_is_type(sym, SymT.SIZE) + ): + self.size_like.add(sym) + size.append(sym) + return size + + def create_symbolic_sizes_strides_storage_offset( + self, + ex: torch.Tensor, + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + """ + Returns a list of symbolic sizes and strides for the given tensor. + We try our best to express stride in terms of the sizes, so as to not + introduce new symbolic variables. + """ + + ex_size = tuple( + self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size() + ) + ex_stride = tuple( + self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride() + ) + ex_storage_offset = self._maybe_specialize_sym_int_with_hint( + ex.storage_offset() + ) + + return self._create_symbolic_sizes_strides_storage_offset( + ex_size, + ex_stride, + ex_storage_offset, + [_is_dim_dynamic(ex, i) for i in range(ex.dim())], + source, + symbolic_context=symbolic_context, + ) + + # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). + # We create symbols in shape_env using the backed hints behind SymInt. + + # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. + # produce_guards will trigger specializations on the outer stuff + + # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). + # + # It's probably good for now but it's important to note that this approach has implications for + # the original shape_env when checking guards in different order. + + # Example: + # --------- + # Consider a function "opt_f" as shown below: + + # @torch.compile() + # def opt_f(x: bool, y: Tensor): + # if x == True: + # return y + torch.randn([4]) + # else: + # return y + # Depending on the sequence of calls, we might install two different sets of guards: + + # 1. opt_f(False, y): + # - "x == False" (always works for any size y) + + # 2. opt_f(True, y): + # - Triggers recompilation and results in guards like: + # - "x == True and y.size(0) == 4" + # - (or "y.size(0) == 4 and x == True") + + # The order of checking the guards matters. In this specific example: + # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, + # we may have an unnecessary shape specialization for y. + def _maybe_specialize_sym_int_with_hint( + self, maybe_sym: IntLikeType + ) -> IntLikeType: + assert isinstance(maybe_sym, (int, torch.SymInt)) + if is_symbolic(maybe_sym): + assert maybe_sym.node.shape_env is not self, ( + "expect the symbol is created from an shape env other than current one." + ) + return maybe_sym.node.require_hint() + return maybe_sym + + @record_shapeenv_event() + def _create_symbolic_sizes_strides_storage_offset( + self, + # NB: SymInt is allowed here due to nested int, normally you don't + # actually pass true symbolic sizes to this function + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + ex_storage_offset: IntLikeType, + is_dim_dynamic: Sequence[bool], + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + hint_overrides: Optional[dict[int, int]] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + dim = len(ex_size) + + if not hint_overrides: + hint_overrides = {} + + # Reimplement the legacy behavior + if symbolic_context is None: + constraint_sizes: list[DimConstraint] = [None] * dim + constraint_strides: list[DimConstraint] = [None] * dim + dynamic_dims = [] + dynamic_strides = [] + for i in range(dim): + # NB: This is encapsulation breaking! Legacy behavior was + # bad. + if is_dim_dynamic[i]: + r = DimDynamic.DYNAMIC + elif self.assume_static_by_default: + r = DimDynamic.STATIC + else: + r = DimDynamic.DUCK + dynamic_dims.append(r) + dynamic_strides.append(r) + dynamic_dims = [DimDynamic.DUCK] * dim + dynamic_strides = [DimDynamic.INFER_STRIDE] * dim + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=dynamic_dims, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + ) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_sizes = symbolic_context.constraint_sizes # type: ignore[attr-defined] + constraint_strides = symbolic_context.constraint_strides # type: ignore[attr-defined] + dynamic_sizes = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + dynamic_strides = symbolic_context.dynamic_strides # type: ignore[attr-defined] + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context + # decision here where if all sizes are static, we are going to + # specialize all of the inner strides/offset too. We don't have to + # do this, and arguably we should ALWAYS allow for dynamic offset, + # this is cheap. + # TODO: This should be DYNAMIC, using DUCK for BC + dynamic_offset = ( + DimDynamic.STATIC + if all(r == DimDynamic.STATIC for r in dynamic_sizes) + else DimDynamic.DUCK + ) + are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes) + + assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(constraint_sizes) == dim + assert len(constraint_strides) == dim + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( + ex_size, source, symbolic_context, hint_overrides=hint_overrides + ) + stride = self._compute_symbolic_stride( + source, + size, + ex_size, + ex_stride, + dynamic_strides, + constraint_strides, + are_sizes_static, + symbolic_context, + ) + + sym_sizes = [ + self.create_symintnode( + sym, + hint=hint_overrides.get(i, hint), + source=TensorPropertySource(source, TensorProperty.SIZE, i), + ) + for i, (sym, hint) in enumerate(zip(size, ex_size)) + ] + + for i, sym in enumerate(sym_sizes): + if isinstance(sym, torch.SymInt) and i in hint_overrides: + self.var_to_hint_override[sym.node.expr] = hint_overrides[i] + + sym_stride = [] + for i, stride_expr in enumerate(stride): + # NB: Don't duck size the stride; instead use the expression + # we computed + assert stride_expr is not None + sym_stride.append( + self.create_symintnode( + stride_expr, + hint=ex_stride[i], + source=TensorPropertySource(source, TensorProperty.STRIDE, i), + ) + ) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=dynamic_offset, + constraint_dim=None, + symbolic_context=symbolic_context, + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + ) + return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset + + def _compute_symbolic_stride( + self, + source: Source, + size: Sequence[sympy.Expr], + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + dynamic_strides: Sequence[DimDynamic], + constraint_strides: Sequence[ + Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]] + ], + are_sizes_static: bool, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + from torch._dynamo.source import TensorProperty, TensorPropertySource + + stride: list[Optional[sympy.Expr]] = [None] * len(size) + candidates: dict[IntLikeType, sympy.Expr] = {} + + # iterate over unbound strides in val ascending order with + # index descending as a tie breaker since for cases like + # [(1, 1), (1, 0)], we want to fill in the right most + # stride first. + val_list = [(val, -i) for i, val in enumerate(ex_stride)] + val_list.sort(key=_nested_int_aware_sort) + + for val, neg_i in val_list: + i = -neg_i + contiguous_stride = ( + i != len(ex_stride) - 1 + and ex_stride[i] == ex_size[i + 1] * ex_stride[i + 1] + ) + if val in (0, 1) and not contiguous_stride: + out_stride = sympy.Integer(val) + else: + dynamic_stride = dynamic_strides[i] + if dynamic_stride == DimDynamic.INFER_STRIDE and val in candidates: + # Set stride to a candidate only for DimDynamic.INFER_STRIDE + out_stride = candidates[val] + else: + # Set INFER_STRIDE to STATIC or DUCK depending on sizes + dyn_stride = dynamic_stride + if dynamic_stride == DimDynamic.INFER_STRIDE: + dyn_stride = ( + DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK + ) + out_stride = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.STRIDE, i), + dynamic_dim=dyn_stride, + constraint_dim=constraint_strides[i], + symbolic_context=symbolic_context, + ) + stride[i] = out_stride + candidates[ex_size[i] * val] = size[i] * out_stride + + assert all(x is not None for x in stride) + return stride + + @record_shapeenv_event() + def create_symintnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int], + source: Optional[Source] = None, + ) -> IntLikeType: + """Create a SymInt value from a symbolic expression + + If you know what the current hint value of the SymInt to be created + is, pass it into hint. Otherwise, pass None and we will make our best + guess + + """ + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: IntLikeType + if isinstance(sym, sympy.Integer): + if hint is not None: + assert int(sym) == hint + out = int(sym) + else: + # How can this occur? When we mark_unbacked, we end up with a real + # tensor that has hints for all sizes, but we MUST NOT create a + # SymNode with a hint, because we're hiding the hint from our eyes + # with the unbacked Symbol. And in fact, the hint compute may be + # inconsistent with size oblivious tests. + if free_unbacked_symbols(sym): + hint = None + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_symfloatnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int | float | bool], + source: Optional[Source] = None, + ) -> FloatLikeType: + """Create a SymFloat value from a symbolic expression""" + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: FloatLikeType + if isinstance(sym, sympy.Float): + if hint is not None: + assert float(sym) == hint + out = float(sym) + else: + # You could give this the same treatment as SymInt above if + # you supported mark_unbacked on a float, but it's a kind of + # strange thing to do though because floats don't get 0/1 + # specialization anyway + if free_unbacked_symbols(sym): + assert hint is None, sym + out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_unspecified_symint_and_symbol( + self, value: int, source: Source, dynamic_dim: DimDynamic + ) -> IntLikeType: + """Create a SymInt wrapping a new unspecified symbol""" + return self.create_symintnode( + self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + + def create_symboolnode(self, sym: sympy.Expr) -> SymBool: + """Create a SymBool object from a sympy boolean expression""" + # This function is only being used in serialization, so we do not track it + # for validation. + return SymBool(SymNode(sym, self, bool, None)) + + def _log_create_unbacked_symbol( + self, + prefix: str, + symbol: sympy.Symbol, + vr: ValueRanges, + source: Optional[Source] = None, + sym_node: Optional[SymNode] = None, + ) -> None: + is_debug = config.extended_debug_create_symbol is not None and str( + symbol + ) in config.extended_debug_create_symbol.split(",") + sloc: Union[str, SLoc] + if source is None: + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + else: + sloc, maybe_extra_debug = source.name, "" + log.info( + "%s %s [%s, %s] %s%s", + prefix, + symbol, + vr.lower, + vr.upper, + sloc, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_unbacked_symbol", + metadata_fn=lambda: { + "symbol": str(symbol), + "node_id": id(sym_node), + "vr": f"[{vr.lower}, {vr.upper}]", + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(), + }, + ) + + @record_shapeenv_event() + def create_unbacked_symfloat(self) -> SymFloat: + """Create a symbolic float without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_FLOAT, self.unbacked_symfloat_counter + ) + self.unbacked_symfloat_counter += 1 + self.counter["create_unbacked_symbol"] += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + sym_node = SymNode(symbol, self, float, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symfloat", symbol, vr, sym_node=sym_node + ) + + return SymFloat(sym_node) + + @record_shapeenv_event() + def create_unbacked_symint(self, source: Optional[Source] = None) -> SymInt: + """Create a symbolic integer without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, self.unbacked_symint_counter, integer=True + ) + self.unbacked_symint_counter += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + sym_node = SymNode(symbol, self, int, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symint", symbol, vr, source, sym_node=sym_node + ) + return SymInt(sym_node) + + def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: + """Check if a sympy symbol matches the naming convention for unbacked symbols""" + return symbol_is_type(symbol, SymT.UNBACKED_INT) + + @record_shapeenv_event() + def create_unbacked_symbool(self) -> SymBool: + """Create a symbolic boolean without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, self.unbacked_symint_counter, integer=True + ) + self.unbacked_symint_counter += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int + sloc = self._get_sloc("default value range for unbacked SymBool") + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) + + sym_node = SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symbool", symbol, vr, sym_node=sym_node + ) + + return SymBool(sym_node) + + @record_shapeenv_event() + def create_unspecified_symbol( + self, + val: Union[int, SymInt, float, SymFloat], + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """ + Create a symbol with an unspecified value + + Compared to standard symbols we do not assume the value is positive, + nor do we specialze on zero or one values. + """ + # 'positive' is None for unspecified symbols, since we can't + # assume that it will be neither positive nor negative. + + # We don't want to specialize zero one val for unspecified symbol + # so that we can always get a new symbol despite val. + return self.create_symbol( + val, + source, + dynamic_dim, + constraint_dim, + positive=None, + do_not_specialize_zero_one=True, + symbolic_context=symbolic_context, + ) + + @record_shapeenv_event() + def create_symbol( + self, + val: int, + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + positive: Optional[bool] = True, + do_not_specialize_zero_one: bool = False, + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """Create a new symbol which is tracked by this ShapeEnv""" + # check if constraint_dim is actually static integer + if ( + isinstance(constraint_dim, StrictMinMaxConstraint) + and constraint_dim.vr.lower == constraint_dim.vr.upper + ): + dynamic_dim = DimDynamic.STATIC + if constraint_dim.vr.lower != val: + raise ConstraintViolationError( + f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " + f"for {source.name}" + ) + if symbolic_context: + from torch._dynamo.source import TensorPropertySource + + assert isinstance(source, TensorPropertySource) + # TODO: storage_offset handling? + assert source.idx is not None + symbolic_context.dynamic_sizes[source.idx] = dynamic_dim + symbolic_context.constraint_sizes[source.idx] = None + constraint_dim = None + + # see note [Tensor Fakification and Symbol Caching] + source_name = source.name + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache + ): + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} + + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and source_name + and ( + source_name + in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] + ) + ): + return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] + + if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE): + out = self.create_unbacked_symint(source).node.expr + self._constrain_range_for_size(out) + + self.unbacked_inputs.add(out) + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + if dynamic_dim is DimDynamic.OBLIVIOUS_SIZE: + self.oblivious_var_to_val[out] = val + return out + + if do_not_specialize_zero_one: + specialize_zero_one = False + else: + specialize_zero_one = self.specialize_zero_one + + assert isinstance(source, Source), f"{type(source)} {source}" + assert not (positive and val < 0), f"positive set for negative value: {val}" + # It's always sound to allocate a symbol as DYNAMIC. If the user + # constrained the symbol, force the symbolic_context to DYNAMIC, because our + # constraint code will do weird stuff if, e.g., it's duck shaped + if constraint_dim is not None: + dynamic_dim = DimDynamic.DYNAMIC + + if dynamic_dim is DimDynamic.STATIC: + out = sympy.Integer(val) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + return out + + elif dynamic_dim is DimDynamic.DUCK: + # duck_shape can be used to globally turn off duck shaping, even + # if it was requested + duck = self.duck_shape + elif dynamic_dim is DimDynamic.DYNAMIC: + duck = False + else: + raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") + + sloc = self._get_sloc() + + if val in (0, 1) and specialize_zero_one: + if val == 0: + return sympy.S.Zero + else: + return sympy.S.One + elif not duck or val not in self.val_to_var: + # If we're not duck shaping, we always create a new symbol + # Even if we're duck shaping, if we haven't seen this particular + # value before, we also create a new symbol + symbol_id = self._generate_unique_id(source.name) + if type(val) is int or is_nested_int(val): + sympy_expr = make_symbol( + SymT.SIZE, symbol_id, positive=positive, integer=True + ) + else: + sympy_expr = make_symbol( + SymT.FLOAT, symbol_id, positive=positive, real=True + ) + self.source_to_var[source_name] = sympy_expr + # We always associate vars to vals + if isinstance(val, int): + self.var_to_val[sympy_expr] = sympy.Integer(val) + elif isinstance(val, float): + self.var_to_val[sympy_expr] = sympy.Float(val) + else: + # Only used for jagged layout nested tensors + self.var_to_val[sympy_expr] = SingletonInt( + val.node.nested_int(), coeff=val.node.nested_int_coeff() + ) + + # Do the appending later, because we always want to populate this + self.var_to_sources[sympy_expr] = [] + # Create a Z3 variable for the new symbol. + self._add_z3var(sympy_expr, int) + + if duck: + # Make sure to reuse this symbol for subsequent duck shaping + # pyrefly: ignore [unsupported-operation] + self.val_to_var[val] = sympy_expr + + if isinstance(val, int): + if positive: + # Add assertions for the newly created symbols + self._add_assertion(sympy_expr > 1) + + # Apply default range, which assumes not zero-one + self.var_to_range[sympy_expr] = self._default_value_range( + do_not_specialize_zero_one + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc( + self._get_sloc( + "user code shown is first use of this value--the guard itself is not " + "due user code but due to 0/1 specialization in the framework; to " + "avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim)" + if self.specialize_zero_one + else None + ), + sloc, + ) + else: + self.var_to_range[sympy_expr] = ( + self._default_unspecified_value_range() + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + + # Small performance optimization: if we have a min-max constraint, + # we can proactively narrow to that range + if isinstance(constraint_dim, StrictMinMaxConstraint): + assert not duck + self._update_var_to_range( + sympy_expr, constraint_dim.vr, is_constraint=True + ) + + vr = self.var_to_range[sympy_expr] + assert vr.is_int + + if val not in vr: + raise ConstraintViolationError( + f"{val} not in range [{vr.lower}, {vr.upper}]" + ) + + range_str = f"[{vr.lower}, {vr.upper}]" + elif isinstance(val, float): + self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float + else: + # Skip var_range logic for SingletonInt + # Only used for jagged layout nested tensors + range_str = "" + + r = sympy_expr + + is_debug = config.extended_debug_create_symbol is not None and str( + sympy_expr + ) in config.extended_debug_create_symbol.split(",") + maybe_more_info = "" + if not is_debug and os.getenv("TORCHDYNAMO_EXTENDED_ADVICE", "1") not in ( + "0", + "", + ): + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}" ' + "or to suppress this message run with " + 'TORCHDYNAMO_EXTENDED_ADVICE="0"' + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "create_symbol %s = %s for %s %s %s%s%s", + sympy_expr, + val, + source.name, + range_str, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_symbol", + metadata_fn=lambda: { + "symbol": str(sympy_expr), + "val": repr(val), + "vr": range_str, + "source": source.name, + "user_stack": structured.from_traceback( + TracingContext.extract_stack() + ), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + + self.counter["create_symbol"] += 1 + else: + # This implements duck-shaping: input sizes that match are assigned + # the same symint + r = self.val_to_var[val] + self.source_to_var[source_name] = r + self.log.debug("create_symbol %s duck sized %s", r, source.name) + + if isinstance(r, sympy.Symbol): + r_sources = self.var_to_sources[r] + r_sources.append(source) + if not source.is_ephemeral() and r_sources[0].is_ephemeral(): + # prefer non-ephemeral source first since it may be guarded on later + r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] + + # This ensures we get zeros in symbol_guard_counts, which makes + # some queries simpler (since we will accumulate mass on 0 this + # way) + self.symbol_guard_counter[r] = 0 + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = r + return r + + def add_var_to_val(self, expr: sympy.Symbol, val: int) -> None: + """Adds a new symbol to the symbolic environment.""" + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) + assert expr not in self.var_to_val, f"{expr} already exists" + self.var_to_val[expr] = sympy.Integer(val) + + def _debug_name(self, source: Source) -> str: + src_name = source.name + return self.source_name_to_debug_name.get(src_name, src_name) + + def _render_range_for_constraint_violation( + self, source: Source, c: Union[StrictMinMaxConstraint, RelaxedUnspecConstraint] + ) -> str: + if isinstance(c, StrictMinMaxConstraint): + lower, upper = c.vr.lower, c.vr.upper + default = self._default_value_range() + if lower <= default.lower: + lower = None + if upper >= default.upper: + upper = None + c_render = ( + f"{self._debug_name(source)} = {source.name} in the specified range" + ) + if lower is not None and upper is not None: + c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" + elif lower is None and upper is not None: + c_render += f" {self._debug_name(source)} <= {upper}" + elif lower is not None and upper is None: + c_render += f" {lower} <= {self._debug_name(source)}" + return c_render + return c.render(source) + + def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]: + """ + Like produce_guards_verbose, but only returns the non-verbose python guard expressions + (no verbose guards produced.) + """ + return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs + + def produce_guards_verbose( + self, + placeholders: Sequence[FakeTensor], + sources: Sequence[Source], + source_ref: Callable[[Source], str] = lambda n: n.name, + *, + guards: Optional[list[ShapeGuard]] = None, + input_contexts: Optional[DimList[SymbolicContext]] = None, + # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). + # (See docs on EqualityConstraint for details of the encoding.) + equalities_inputs: Optional[EqualityConstraint] = None, + _simplified: bool = False, + # Indicates if we should produce guards for known static values. + ignore_static: bool = True, + langs: tuple[str, ...] = ("python", "verbose_python"), + ) -> list[_ShapeGuardsHelper]: + """ + Generates a list of guards strings which, when evaluated in a context that + defines tensors for all the sources, returns True or False depending + on if the guards in the list evaluated to True or not. Primarily used by Dynamo, + but this is also helpful for manual testing of guards (see + evaluate_guards_for_args) + + For convenience in testing, a source is allowed to be a str, + in which case we will assume it is a LocalSource + + simplified lets you omit duck sizing, equality and 0/1 guards. + This is useful for testing when you don't care about the boilerplate + guards, and it may be helpful for user output too (be careful though; + some equality guards are nontrivial! It would be nice to get simplified + output to print them too). It's private because it's not + intended for normal use + + Returns guards in python and python with verbose comments (verbose) by + default. + """ + self.log.info("produce_guards") + + # Check if we get to the same ShapeEnv state by replaying the recorded events. + # This will create a new ShapeEnv instance, and call all recorded function + # calls on this new instance. Finally, it will check whether this new instance + # has equal state. + # + # It's important that we do it in the beginning of this function, since it modifies + # self.dim_constraints through its execution. Changes that happen in this method + # aren't interesting, since this is the function call we wish to reproduce at the + # end. If we wish to simply reproduce ShapeEnv instances even after this call, + # this method should also be recorded. + if self.check_recorded_events: + shape_env = replay_shape_env_events(self.events) + self.check_equal(shape_env) + + assert len(placeholders) == len(sources), ( + f"len({placeholders}) != len({sources})" + ) + Tensorlike = (torch.Tensor, FakeTensorMeta) + + def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: + return StatelessSymbolicContext( + # Ignored; only the constraints part is relevant below. + dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(), + constraint_sizes=[None] * t.dim(), + constraint_strides=[None] * t.dim(), + ) + + # Expand optional inputs, or verify invariants are upheld + if input_contexts is None: + # pyrefly: ignore [bad-assignment] + input_contexts = [ + # pyrefly: ignore [bad-argument-type] + _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None + for t in placeholders + ] + else: + assert len(input_contexts) == len(placeholders) + + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): + if isinstance(t, Tensorlike): + if context is None: + # pyrefly: ignore [bad-argument-type] + input_contexts[i] = _create_no_constraints_context(t) + else: + assert isinstance(t, (SymInt, int, SymFloat, float)) + assert not isinstance(context, list) + + # It took a lot of sweat to figure out the algorithm here. Let's + # explain how it works. + # + # The ShapeEnv lifecycle looks something like this: + # + # - For each input, you either generate a fresh Sympy symbol (s0) to + # represent its value (a binding site), or you reuse some + # preexisting symbol or expression, skipping the symbol allocation + # (e.g., duck sizing to a preexisting symbol, or expressing a + # stride as a multiplication of a separate stride and size.) + # Naively, you might expect to bind a fresh Sympy symbol for + # every input, but this is fairly wasteful as most of these + # symbols immediately simplify away, and if you don't eagerly + # specialize, e.g., 0/1 symbols, you end up with very complicated + # expressions that are not optimizable in practice. + # + # - You perform some compute on these symbols, occasionally + # introducing guards on boolean expressions on these symbols. + # In particular, whenever we guard on equality (_maybe_guard_rel), + # we can simplify shapes; e.g., when s0 == s1 * 2, we can now + # replace all occurrences of s0 with s1 * 2. Sometimes, a + # boolean expression evaluation doesn't introduce a guard, as + # the guard is already entailed by the simplifications we have + # applied. + # + # - In the end, you have a bunch of replacements (saying how to + # simplify shapes) and a bunch of guards (all the equality guards + # are trivial, because they're covered by the replacements). + # + # From the ShapeEnv, we must generate a Python expression that, when + # evaluated on a set of inputs, tells us whether or not these boolean + # expressions would have evaluated in the same way. However, + # we cannot easily compute this, as we elide recording boolean + # expressions when we think they are vacuously true. Thus, we seek + # an approximation: we must generate an expression, if true, would have + # produced an "equivalent" ShapeEnv, which would answer guard + # expressions in the same way. + # + # Our notion of equivalence is a bit subtle. For example, consider + # the ShapeEnv created from an input of size (5, 4) versus (4, 4) + # (no other guards.) Duck sizing would generate (s0, s1) in the first + # case but (s0, s0) in the second. We do NOT assume that size + # variables are disjoint; so in fact a graph that assumes the input + # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not + # vice versa. However, consider an analogous case (1,) versus (2,). + # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT + # subsume the (1,) graph because we assume that any size variables + # is NOT 0/1 (and make simplifications according to this; e.g., if + # we queried s0 == 0, we would immediately return False without + # returning a guard.) + # + # So, it is perhaps easier to flip things on their head: the guard + # expressions we generate here say what simplifications are valid, + # and what are not. Below, we explain each of the guard expressions + # we generate + + # TODO: Make this more efficient by binding all the size/stride/offsets + # to locals before performing tests on them. + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + # Actual codegen must be delayed as we don't necessarily know what + # the symbol mapping is + input_guards = [] + + symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( + list + ) + symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = ( + collections.defaultdict(set) + ) + constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] + + printers: list[_ShapeGuardPrinter] = [] + py_printer = ShapeGuardPythonPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + for lang in langs: + if lang in ["python", "verbose_python"]: + printers.append(py_printer) + elif lang == "cpp": + printers.append( + _ShapeGuardCppPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + ) + else: + raise NotImplementedError(f"Unknown lang: {lang}") + + def record_constraint_violation( + warn_only: bool, + debug_name: str, + msg: str, + hint: Optional[Callable[[], str]] = None, + ) -> None: + constraint_violations.append( + (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) + ) + + def is_dim(src: object) -> TypeGuard[TensorPropertySource]: + return ( + isinstance(src, TensorPropertySource) + and src.prop is TensorProperty.SIZE + ) + + if equalities_inputs: + source_index = {} + for i, src in enumerate(sources): + source_index[src.name] = i + + def get_expression(tensor_dim_src: Source) -> sympy.Expr: + fake = placeholders[source_index[tensor_dim_src.base.name]] # type: ignore[attr-defined] + assert tensor_dim_src.idx is not None # type: ignore[attr-defined] + symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined] + if isinstance(symint, torch.SymInt): + return symint.node.expr + else: + assert type(symint) is int, f"Expected int, got {type(symint)}" + return sympy.Integer(symint) + + for src1, src2 in equalities_inputs.source_pairs: + expr1, expr2 = get_expression(src1), get_expression(src2) # type: ignore[] + # Check whether given input shape values satisfy a specified equation s = s'. + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) + if not concrete_val: + raise ConstraintViolationError( + f"{src1.name} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" + " is not equal to " + f"{src2.name} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" + ) + + for srcEq, root, fn in equalities_inputs.derived_equalities: + expr1 = get_expression(srcEq) + # recall that root is either a phantom symbol or an input source + if isinstance(root, sympy.Symbol): + expr2, debug_name = root, self.var_to_sources[root][0].name + elif isinstance(root, sympy.Integer): + expr2, debug_name = root, str(root) + else: + expr2, debug_name = get_expression(root), self._debug_name(root) + expr2_ = fn(expr2) + # Check whether given input shape values satisfy a specified equation s = fn(s'). + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) + if not concrete_val: + raise ConstraintViolationError( + f"Expected input {srcEq.name} to be equal to " + f"{fn(sympy.Symbol(debug_name))}, " + f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " + f"but got {expr1.xreplace(self.var_to_val)}" + ) + + for phantom_symbol in equalities_inputs.phantom_symbols: + if isinstance(phantom_symbol, sympy.Symbol): + # we created additional phantom symbols that are not input shape dimensions + symbol_to_source[phantom_symbol].extend( + self.var_to_sources[phantom_symbol] + ) + + # How do we know what the value of s0 is? Fresh variables can only be + # bound by inputs, so there MUST be some other input which binds the + # variable. If there is no such input, this is an error in our + # system. We record where all symbols come from, to help you diagnose + # why those symbols didn't occur. + # + # In fact, generally speaking it is only possible for the "outermost" + # user of a ShapeEnv to evaluate the guards, because some inputs may + # not be available to inner levels. For example, Dynamo can guard on + # tensors that never actually become graph arguments (they are + # pruned). In this case, only Dynamo knows about these arguments. + def track_symint( + source: Source, val: IntLikeType, constraint: DimConstraint = None + ) -> None: + log.debug( + "track_symint %s %s %s", + LazyString(lambda: source.name), + val, + constraint, + ) + assert not isinstance(val, SymInt) or is_symbolic(val) + + if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: + val = val.node.maybe_as_int() + + if isinstance(val, SymInt): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + if constraint is not None and not isinstance( + constraint, RelaxedUnspecConstraint + ): + symbol_to_constraints[s].add(constraint) + else: + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + # try inferring the ranges of the expr s + sym_vrs = { + x: self.var_to_range.get(x, None) for x in s.free_symbols + } + if any(vr is None for vr in sym_vrs.values()): + # some of the free symbols in s don't have ranges + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + if s.is_number: + i = int(s) + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if i not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + + def hint(s: sympy.Expr) -> str: + sexpr = py_printer.doprint(s) + return f"{sexpr}." + + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be equal to " + ) + record_constraint_violation( + constraint.warn_only, + self._debug_name(source), + msg, + hint=functools.partial(hint, s), + ) + + input_guards.append((source, s)) + else: + s = sympy.Integer(val) + input_guards.append((source, s)) + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + if not ( + s == constraint.vr.lower == constraint.vr.upper + ): # allow static constraints + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if val not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + user_stack = self.specialization_stacks.get(source, None) + msg = ( + f"You marked {self._debug_name(source)} as dynamic but your code " + f"specialized it to be a constant ({val}). If you're using mark_dynamic, " + f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, " + f"replace it with either Dim.STATIC or Dim.AUTO." + + ( + "\n\nUser stack:\n" + "".join(user_stack.format()) + if user_stack + else "" + ) + ) + record_constraint_violation( + constraint.warn_only, self._debug_name(source), msg + ) + + def track_symfloat(source: Source, val: FloatLikeType) -> None: + log.debug("track_symfloat %s %s", LazyString(lambda: source.name), val) + assert not isinstance(val, SymFloat) or is_symbolic(val) + + if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: + val = val.node.maybe_as_float() + + if isinstance(val, SymFloat): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + input_guards.append((source, s)) + else: + s = sympy.Float(val) + input_guards.append((source, s)) + + # pyrefly: ignore [no-matching-overload] + for t, source, context in zip(placeholders, sources, input_contexts): + if isinstance(source, str): + from torch._dynamo.source import LocalSource + + source = LocalSource(source) + assert isinstance(source, Source) + if t is None: + continue + if isinstance(t, (SymInt, int)): + constraint = ( + None if context is None else getattr(context, "constraint", None) + ) + track_symint(source, t, constraint) + continue + elif isinstance(t, (SymFloat, float)): + track_symfloat(source, t) + continue + assert isinstance(t, Tensorlike) + if is_traceable_wrapper_subclass(t): + from torch._dynamo.source import AttrSource + + assert isinstance(context, SubclassSymbolicContext) + + # For subclasses, we need to track symints on BOTH the outer + # and inner tensors. + # TODO: type this better + sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [ + (source, t, context.constraint_sizes, context.constraint_strides) + ] + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + inner_context = context.inner_contexts[attr] + sources_tensors_constraints.append( + ( + AttrSource(source, attr), + inner_t, + inner_context.constraint_sizes, # type: ignore[attr-defined] + inner_context.constraint_strides, # type: ignore[attr-defined] + ) + ) + else: + sources_tensors_constraints = [ + (source, t, context.constraint_sizes, context.constraint_strides) # type: ignore[attr-defined] + ] + + for ( + src, + curr_t, + constraint_size, + constraint_stride, + ) in sources_tensors_constraints: + if is_sparse_any(curr_t): + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + else: + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + + for i, ss in enumerate(curr_t.stride()): + property_source = TensorPropertySource( + src, TensorProperty.STRIDE, i + ) + track_symint(property_source, ss, constraint_stride[i]) + track_symint( + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), + curr_t.storage_offset(), + ) + + # 1. Every input must equal the final simplified symbolic expression + # stored on the placeholder. Given a placeholder (s0*2, s1), + # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. + # This does a lot of work: it covers duck sizing and equality guards. + all_exprs: list[list[str]] = [[] for _ in langs] + + self.dim_constraints = DimConstraints( + symbol_to_source, + self.var_to_val, + set(symbol_to_constraints.keys()), + self.source_name_to_debug_name, + ) + + if not _simplified: + for source, expr in input_guards: + srcname = source.name + if self._translation_validation_enabled: + # Ignore sources that were not turned into SymInts. + if srcname in self.source_to_symbol: + self._add_target_expr( + sympy.Eq(self.source_to_symbol[srcname], expr) + ) + + # Small optimization + if ( + isinstance(expr, sympy.Symbol) + and symbol_to_source.get(expr) + and source == symbol_to_source[expr][0] + ): + continue + + # This logic excludes static values found on tensors from guarding, because + # dynamo's check_tensor_fn does that (see guards.cpp). + # However, for non tensor sources, we still need to guard here. + if ignore_static and isinstance(source, TensorPropertySource): + if expr.is_number: + self.log.debug( + "Skipping guard %s", f"{source_ref(source)} == {expr}" + ) + continue + + if is_dim(source): + self.dim_constraints.add_equality(source, expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + res = f"{printer.print_source(source)} == {printer.doprint(expr)}" + + if lang == "verbose_python": + if (s0 := self.source_to_var.get(srcname)) is not None: + if source != self.var_to_sources[s0][0]: + res = ( + f"{res} # duck sizing added this equality because these " + f"variables had the same size {self.var_to_val[s0]} " + "(to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)" + ) + elif (sloc := self.replacements_slocs.get(s0)) is not None: + res = f"{res} # {sloc}" + else: + res = f"{res} # (unknown var {s0}, please file a bug)" + else: + res = f"{res} # (unknown source {srcname}, please file a bug)" + exprs.append(res) + + if ( + isinstance(source, TensorPropertySource) + and source.prop is TensorProperty.SIZE + and equalities_inputs + and len(expr.free_symbols) == 1 + ): + symbol = next(iter(expr.free_symbols)) + if ( + isinstance(expr, sympy.Symbol) + and expr in symbol_to_constraints + and not equalities_inputs.is_equal( + source, symbol_to_source[expr][0] + ) + ): + msg = ( + f"The values of {self._debug_name(source)} = {source.name} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name} " + "must always be equal." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + if ( + not isinstance(expr, sympy.Symbol) + and symbol in symbol_to_constraints + and not equalities_inputs.is_derived( + source, + symbol_to_source[symbol][0], + lambda x: expr.xreplace({symbol: x}), + ) + ): + src = symbol_to_source[symbol][0] + msg = ( + f"The values of {self._debug_name(source)} = {source.name} must always be related to " + f"the values of {self._debug_name(src)} = {src.name} by " + f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + # NB: Not necessary to report constraint violations here: + # constraints are guaranteed to be on symbols (we've already + # caught constants and non-atomic expressions), so we only + # have relational constraints, but we don't support those + # at the moment + + # 2. Every guard must evaluate to True (but remember many guards + # like s0 == s1*2 because trivial due to simplification) + issued = set() + + def issue_guard(guard: ShapeGuard) -> None: + expr = self.simplify(guard.expr) + + # Avoid re-issuing the same guard. + if expr in issued: + return + + issued.add(expr) + + try: + is_trivial = False + if any( + is_dim(source) + for s in expr.free_symbols + for source in symbol_to_source[s] + ): + assert self.dim_constraints is not None + is_trivial = self.dim_constraints.add(expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + guard_expr = printer.doprint(expr) + if lang == "verbose_python": + guard_expr = f"{guard_expr} # {guard.sloc}" + exprs.append(guard_expr) + + self._add_target_expr(expr) + # A non-relational constraint on a single sizevar can violate + # a constraint + if not is_trivial and len(expr.free_symbols) == 1: + symbol = next(iter(expr.free_symbols)) + source = symbol_to_source[symbol][0] + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = ( + f"Not all values of {var_with_range} " + f"satisfy the generated guard {py_printer.doprint(expr)}." + ) + record_constraint_violation( + c.warn_only, self._debug_name(source), msg + ) + elif isinstance(c, RelaxedUnspecConstraint): + # This is fine, we allow guards here as long as it + # didn't constrain it to one value (we don't + # actually know this; this depends on our + # ValueRanges reasoning capability) + pass + else: + raise AssertionError(f"unrecognized constraint {c}") + except Exception: + self.log.warning("Failing guard allocated at %s", guard.sloc) + raise + + # First, issue all guards. + # This removes all the checks that follow from bounds + # We could simply emit those and also the bounds 2 <= size when necessary + for guard in guards if guards is not None else self.guards: + if ( + self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is not None + ): + continue + + issue_guard(guard) + + # Because there are guards that export's constraint solver can suggest good fixes for, that we may have + # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards), + # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts, + # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide + # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph). + for ra in self.deferred_runtime_asserts.get(None, []): + if self._maybe_evaluate_static(ra.expr, axioms=()) is not None: + continue + expr = self.simplify(ra.expr) + + self.dim_constraints.add(expr) + + # 3. Every symbol must be within its value range (this handles 0/1 + # specialization too). + for symbol, sources in symbol_to_source.items(): + r = self.var_to_range.get(symbol) + if r is None: + continue + vr_sloc = self.var_to_range_sloc[symbol] + + assert sources + bounds = [] + rf = source_ref(sources[0]) + verbose_expr = "" + if r.lower not in (-sympy.oo, -int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Ge(symbol, r.lower)) + # Only print lower bound in simplified mode if it is not the + # default + if not _simplified or r.lower != self._default_value_range().lower: + bounds.append(sympy.Le(r.lower, symbol, evaluate=False)) + verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}" + if r.upper not in (sympy.oo, int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Le(symbol, r.upper)) + # nontrivial upper bound is always interesting + bounds.append(sympy.Le(symbol, r.upper, evaluate=False)) + if verbose_expr: + verbose_expr = f"{r.lower} <= {rf} <= {r.upper} # {vr_sloc.lower} and {vr_sloc.upper}" + else: + verbose_expr = f"{rf} <= {r.upper} # {vr_sloc.upper}" + if bounds: + bound = sympy.And(*bounds, evaluate=False) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append(verbose_expr) + else: + exprs.append(printer.doprint(bound)) + # NB: verbose_exprs are done above + + # Check constraints + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + # TODO: With int_oo, I think this condition is a noop + # now + if not (c.vr & self._default_value_range()).issubset(r): + source = sources[0] + + expr = sympy.And( + sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper) + ) + guard_expr = py_printer.doprint(expr) + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" + record_constraint_violation( + c.warn_only, + self._debug_name(source), + msg, + ) + # We NaN specialize, which means similar to 0/1 specialization we + # should assume that the float is NOT nan. This is load bearing + # if you have something like an equality guard, nan will play + # merry hell with the reasoning. + if symbol_is_type(symbol, SymT.FLOAT): + res = f"not math.isnan({py_printer.print_source(sources[0])})" + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append( + f"{res} # implicit guard for float input due to NaN specialization in the framework" + ) + elif lang == "python": + exprs.append(res) + elif lang == "cpp": + exprs.append(f"~std::isnan({printer.print_source(sources[0])})") + else: + raise NotImplementedError(f"Unimplemented for lang: {lang}") + + if constraint_violations: + warn_msgs: list[str] = [] + error_msgs: list[str] = [] + debug_names = set() + for warn_only, debug_name, msg_cb in constraint_violations: + if warn_only: + str_msg = f" {len(warn_msgs) + 1}. {msg_cb()}" + warn_msgs.append(str_msg) + else: + str_msg = f" - {msg_cb()}" + error_msgs.append(str_msg) + # pyrefly: ignore [bad-argument-type] + debug_names.add(debug_name) + if len(error_msgs) > 0: + debug_names_str = ", ".join(sorted(debug_names)) + err = "\n".join(error_msgs) + raise ConstraintViolationError( + f"Constraints violated ({debug_names_str})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + f"{err}" + ) + elif len(warn_msgs) > 0: + log.debug("%s Warning only constraints violated", len(warn_msgs)) + + signpost_event( + "dynamic", + "produce_guards", + { + **self.co_fields, + **self.counter, + "num_guards": len(all_exprs[0]), + "free_symbols": sum(1 for v in symbol_to_source.values() if v), + # The keys are meaningless from an aggregate perspective, so + # don't include them. Biggest first. + "symbol_guard_counts": sorted( + self.symbol_guard_counter.values(), reverse=True + ), + }, + ) + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import PopulateValidator + + # Add all deferred runtime assertions; these are not technically + # handled by produce_guards but we need to put them in the target + # set + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + self._add_target_expr(ra.expr) + + # Add value range bound guards for all symbols with no trivial bounds. + # Reason: '_maybe_evaluate_static' may eliminate guards based on the + # refined value ranges. + for sym, vr in self.var_to_range.items(): + if vr.lower not in (-sympy.oo, -int_oo): + self._add_target_expr(sympy.Le(vr.lower, sym)) + if vr.upper not in (sympy.oo, int_oo): + self._add_target_expr(sympy.Le(sym, vr.upper)) + + # Before validating, populate the input of the validator with the + # built FX graph. + with fx_traceback.preserve_node_meta(): + PopulateValidator(self.graph, self.validator).run() + + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() + + helpers: list[_ShapeGuardsHelper] = [] + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "cpp": + assert isinstance(printer, _ShapeGuardCppPrinter) + helpers.append(_CppShapeGuardsHelper(exprs, printer.source_to_symbol)) + else: + helpers.append(_ShapeGuardsHelper(exprs)) + return helpers + + def produce_guards_expression( + self, + placeholders: Sequence[Union[SymInt, FakeTensor]], + *, + guards: Optional[list[ShapeGuard]] = None, + ignore_static: bool = True, + ) -> Optional[str]: + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + + arg_names = [f"t{i}" for i in range(len(placeholders))] + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) + return None + + def evaluate_symexpr(self, code: str) -> Union[int, float, bool]: + """ + To be used by compile_fx to evaluate symexprs + """ + args = {str(e): val for e, val in self.var_to_val.items()} + return eval(code, SYMPY_INTERP, args) + + def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]: + """ + To be used by compile_fx to deserialize symexprs + """ + args = { + str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None)) + for e, val in self.var_to_val.items() + } + return eval(code, SYMPY_INTERP, args) + + def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool: + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. + """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + + def evaluate_guards_for_args( + self, + placeholders: Sequence[FakeTensor], + args: Sequence[Tensor], + *, + ignore_static: bool = True, + ) -> bool: + """Generate guards for a graph's placeholder values and evaluate the guards with args""" + code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) + if code: + return self.evaluate_guards_expression(code, args) + return True + + def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]: + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + # pyrefly: ignore [bad-assignment] + symints = { + s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol) + } + guards = [ + g for g in self.guards if all(s in symints for s in g.expr.free_symbols) + ] + return guards + + def bind_symbols( + self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor] + ) -> dict[sympy.Symbol, int]: + """ + Given a paired list of placeholders (fake tensors with + symbolic sizes) and concrete arguments (regular tensors + with real sizes), returns a dictionary mapping each + symbol to its real value. So for example, if you + have a placeholder with size (s0, s1), binding + (2, 4) to it will give you {s0: 2, s1: 4}. This is + not guaranteed to bind ALL symbols in the ShapeEnv; + we can't bind a symbol if it doesn't occur in any placeholder, + and symbols that already have replacements won't get bindings. + + This is a little duplicative with evaluate_guards but + it's different enough that it seemed cleanest to make + another copy. This assumes the guards are already checked, + though if it's cheap we'll check for shenanigans + """ + bindings: dict[sympy.Symbol, int] = {} + + def bind_symint(arg: object, val: object) -> None: + if isinstance(val, SymInt): + assert isinstance(arg, int) + s = val.node.expr + + if isinstance(s, sympy.Symbol): + if s in bindings: + assert bindings[s] == arg, f"{bindings[s]} != {arg}" + else: + bindings[s] = arg + elif isinstance(-s, sympy.Symbol): + if -s in bindings: + assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" + else: + bindings[-s] = -arg + + for t, arg in zip(placeholders, args): + if t is None: + continue + if isinstance(t, SymInt): + bind_symint(arg, t) + continue + assert isinstance(t, torch.Tensor) + for i, s in enumerate(t.size()): + bind_symint(arg.size(i), s) + for i, s in enumerate(t.stride()): + bind_symint(arg.stride(i), s) + bind_symint(arg.storage_offset(), t.storage_offset()) + + return bindings + + def get_nontrivial_guards(self) -> list[SympyBoolean]: + """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" + return [ + self.simplify(guard.expr) + for guard in self.guards + if self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is None + ] + + def format_guards(self, verbose: bool = False) -> str: + """Format this shape env's guard expressions with optional traceback info if verbose""" + + return "\n".join( + f" - {guard.expr}{' ' + str(guard.sloc) if verbose else ''}" + for guard in self.guards + ) + + def bound_sympy( + self, expr: sympy.Expr, size_oblivious: bool = False + ) -> ValueRanges: + """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + # TODO: maybe it's guaranteed x in is var_to_range? + var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} + if size_oblivious: + # Clamp values of size-like variables + # NB: discarding the old upper bound in intentional, per + # https://github.com/pytorch/pytorch/pull/123675 + for x in self.size_like & var_to_range.keys(): + if var_to_range[x] is not None: + # NB: do NOT set upper to 2 ** 48, we're using this solely + # to determine if we can do size-like replacement, the + # upper bound is irrelevant here + var_to_range[x] = ValueRanges(2, int_oo) + return bound_sympy(expr, var_to_range) # type: ignore[arg-type] + + @_lru_cache + def get_axioms( + self, + symbols: Optional[tuple[sympy.Symbol]] = None, + compute_hint: bool = False, + ) -> tuple[SympyBoolean, ...]: + """ + Given the symbols in an expression, it returns all the runtime asserts that have those symbols + concatenated with all the guards. + If symbols is None, it returns all the runtime asserts (and all the guards) + """ + if symbols is None: + runtime_asserts = ( + r.expr for rs in self.deferred_runtime_asserts.values() for r in rs + ) + else: + runtime_asserts = ( + r.expr + for s in symbols + if s not in self.var_to_val + for r in self.deferred_runtime_asserts.get(s, ()) + ) + guards: Iterator[SympyBoolean] = (g.expr for g in self.guards) + axioms: Iterator[SympyBoolean] = itertools.chain(guards, runtime_asserts) + if compute_hint: + axioms = ( + canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms + ) + return tuple(dict.fromkeys(axioms).keys()) + + @lru_cache(None) + def get_implications( + self, e: SympyBoolean + ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: + """Given a expression, it returns a list of predicates that follow from it""" + equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} + + def add_expr(expr: SympyBoolean) -> None: + expr = canonicalize_bool_expr(expr) + if isinstance(expr, (sympy.Eq, sympy.Ne)): + # No need to canonicalize + # TODO We could further canonicalize Eq ordering the lhs and rhs somehow + # With this, we could remove the need for the commutativity part + opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne + # Commutativity of == and != + equiv[type(expr)(expr.lhs, expr.rhs, evaluate=False)] = sympy.true + equiv[type(expr)(expr.rhs, expr.lhs, evaluate=False)] = sympy.true + equiv[opposite(expr.lhs, expr.rhs, evaluate=False)] = sympy.false + equiv[opposite(expr.rhs, expr.lhs, evaluate=False)] = sympy.false + else: + # Expr and negation + equiv[expr] = sympy.true + # we do not pass evaluate=False like others on purpose here! + # we want not(a=b and not ~(a Optional[sympy.Basic]: + """ + Tries to evaluate expr without introducing guards + + If unbacked_only == True, then we only do substitutions on + unbacked SymInts (leaving regular hinted integers alone). This could + result in an expression that still contains backed SymInts, which you + could then potentially guard on. + + Use compute_hint == True if you are trying to compute a non-binding + hint for the particular hint values of backed and unbacked SymInts, + e.g., if s0 happens to be 3 this run, compute_hint will substitute s0 with 3. + """ + + # axioms with compute hint NYE + assert not compute_hint or not axioms + expr = self.simplify(expr, size_oblivious) + + if compute_hint: + expr = expr.xreplace(self.var_to_val).xreplace(self.unbacked_var_to_val) + + expr = canonicalize_bool_expr(expr) + + def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: + if not self._resimplify_floor_div_axioms: + return + self._resimplify_floor_div_axioms = False + new_items = {} + for k, v in list(axioms.items()): + # A FloorDiv in implications could have became CleanDiv at this point, due to new facts + # to the shapeEnv. This handles such issue but its not ideal. This is the only expression + # simplification that depends on the global state of shape env. + # TODO try to get rid of CleanDiv since it breaks the invariant that's simplifications of sympy + # expressions only depend on the expression itself. + if k.has(FloorDiv): + new_items.update({self.simplify(k): v}) + axioms.update(new_items) + + # Pattern matching + if axioms is None: + resimplify_floor_div(self.axioms) + subst = self.axioms + else: + subst = {} + for e in axioms: + if e.free_symbols.issubset(expr.free_symbols): + subst.update(dict(self.get_implications(self.simplify(e)))) + + resimplify_floor_div(subst) + + expr = expr.xreplace(subst) + # TODO: compute hint might have gotten broken here + + fs = expr.free_symbols + + if not fs and (expr.is_number or expr.is_Boolean): + return expr + + if var_to_range is None: + var_ranges = self.var_to_range + else: + var_ranges = dict(var_to_range) + + symbol_info = tuple( + _SymbolInfo( + s, + var_ranges.get(s), + self.var_to_val.get(s), + s in self.size_like, + ) + for s in sorted(fs, key=str) # TODO: speed up sort? + ) + + r = _maybe_evaluate_static_worker( + expr, symbol_info, unbacked_only, size_oblivious + ) + return r + + @_lru_cache + def replace(self, expr: _SympyT) -> _SympyT: + """ + Apply symbol replacements to any symbols in the given expression. + """ + replacements = {} + # pyrefly: ignore [missing-attribute] + for s in expr.free_symbols: + r = self._find(s) + + # Micro-optimization: only do replacements if r and s are different + # Otherwise, xreplace is not a no-op and will trigger expensive + # assumption queries if expr has a relational node. + if not r.is_Symbol or r != s: + replacements[s] = r + if replacements: + # pyrefly: ignore [missing-attribute] + return safe_expand(expr.xreplace(replacements)) + else: + return expr + + @_lru_cache + def _update_divisible(self) -> None: + new_divisible = set() + for k in self.divisible: + res = self.replace(k) + if not res.is_number: + new_divisible.add(k) + + self.divisible = new_divisible + self._update_version_counter() + + @_lru_cache + def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: + """Use known constraints and replacements to simplify the given expr""" + expr = safe_expand(expr) + expr = self.replace(expr) + + # Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced + # expression when creating contiguous strides. + if not size_oblivious: + min_max_replacements = {} + for atom in expr.atoms(Max): # type: ignore[has-type] + if len(atom.args) > 2: + continue + a, b = atom.args + if b == 1 or b == 0: + a, b = b, a + + if a == 1 and self._maybe_evaluate_static(sympy.Ge(b, 1)): + min_max_replacements[atom] = b + if a == 0 and self._maybe_evaluate_static(sympy.Ge(b, 0)): + min_max_replacements[atom] = b + if min_max_replacements: + expr = expr.xreplace(min_max_replacements) + + if expr.has(TruncToInt): + trunc_replacements = {} + for atom in expr.atoms(TruncToInt): + if isinstance(atom.args[0], IntTrueDiv): + base, divisor = atom.args[0].args + if base % divisor == 0: + trunc_replacements[atom] = CleanDiv(base, divisor) + else: + # TruncToInt(IntTrueDiv(a,b)) == FloorDiv(a, b) + trunc_replacements[atom] = FloorDiv(base, divisor) + if trunc_replacements: + expr = expr.xreplace(trunc_replacements) + + # TODO it would seem that this pass is not necessary given the + # below replacement of // with /, but for nested FloorDivs + # the non-recursive replacement doesn't work, and + # recursive makes it hard to look up divisibility, + # because existing divisibility info has FloorDiv in it, not / + # for now just do a separate pass to catch common nested case + if expr.has(FloorDiv): + self._update_divisible() + div_replacements = {} + for atom in expr.atoms(FloorDiv): + base, divisor = atom.args + if isinstance(divisor, FloorDiv): + base1, divisor1 = divisor.args + if ( + self.replace(Mod(base, divisor)) in self.divisible + and base == base1 + and self.replace(Mod(base1, divisor1)) in self.divisible + ): + div_replacements[atom] = divisor1 + if div_replacements: + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) + if expr.has(FloorDiv): + div_replacements = {} + pows = expr.atoms(sympy.Pow) + rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) + for fd in expr.atoms(FloorDiv): + base, divisor = fd.args + if self.replace(Mod(base, divisor)) in self.divisible: + div_replacements[fd] = CleanDiv(base, divisor) + if div_replacements: + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference( + new_expr.atoms(sympy.Integer) + ) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr + return expr + + # TODO: overload for allow_none literal + @lru_cache(256) + def size_hint( + self, expr: sympy.Basic, *, allow_none: bool = False + ) -> Optional[sympy.Basic]: + """ + Gets a size hint for a given expression from the underlying shapes we had. + Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + result_expr = safe_expand(expr).xreplace(self.var_to_val) + if not result_expr.is_number: + from torch.utils._sympy.singleton_int import SingletonInt + + if isinstance(result_expr, SingletonInt): + return None + r = self._maybe_evaluate_static(result_expr, compute_hint=True) + if r is not None: + return r + if allow_none: + return None + + if self.oblivious_var_to_val: + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + correct_hint = result_expr.xreplace(self.oblivious_var_to_val) + counterfactual_hint = result_expr.xreplace( + {k: max(v, 2) for k, v in self.oblivious_var_to_val.items()} + ) + if ( + not correct_hint.free_symbols + and not counterfactual_hint.free_symbols + ): + if correct_hint == counterfactual_hint: + log.info("oblivious_size hit %s -> %s", expr, correct_hint) + return correct_hint + else: + log.info( + "oblivious_size counterfactual failed %s -> %s != %s", + expr, + correct_hint, + counterfactual_hint, + ) + else: + log.info( + "oblivious_size miss %s -> %s (counterfactual: %s)", + expr, + correct_hint, + counterfactual_hint, + ) + + if self.unbacked_var_to_val: + unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) + if not unsound_expr.free_symbols: + log.warning( + "propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(expr), + "result": repr(unsound_expr), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + self.guard_or_defer_runtime_assert( + sympy.Eq(result_expr, unsound_expr), + f"propagate_real_tensors: {result_expr} == {unsound_expr}", + ) + return unsound_expr + + raise self._make_data_dependent_error(result_expr, expr) + return result_expr + + # NB: keep in sync with size_hint + @lru_cache(256) + def has_hint(self, expr: sympy.Expr) -> bool: + result_expr = safe_expand(expr).xreplace(self.var_to_val) + return ( + result_expr.is_number + or self._maybe_evaluate_static(result_expr) is not None + ) + + def _make_data_dependent_error( + self, + expr: sympy.Basic, + unhinted_expr: sympy.Basic, + *, + expr_sym_node_id: Optional[int] = None, + ) -> GuardOnDataDependentSymNode: + # TODO: in a Dynamo context, having user code, and having the + # name of the local, will be much better + size_like_symbols = [] + for s in expr.free_symbols: + stacktrace = "".join(self.var_to_stack[s].format()) + self.log.debug( + "Data dependent variable '%s' allocated at:\n%s", s, stacktrace + ) + if s in self.size_like: + size_like_symbols.append(s) + size_oblivious_result_msg = "" + sloc, maybe_extra_debug = self._get_stack_summary(True) + if expr.is_integer: # type: ignore[attr-defined] + desc = ( + "Could not extract specialized integer from data-dependent expression" + ) + else: + desc = "Could not guard on data-dependent expression" + size_oblivious_result_msg = ( + "consider using data-dependent friendly APIs such as " + "guard_or_false, guard_or_true and statically_known_true." + ) + + msg = ( + f"{desc} {expr} (unhinted: {unhinted_expr}). " + f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" + f"{size_oblivious_result_msg}\n" + f"Caused by: {sloc}\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' + "For extended logs when we create symbols, also add " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n' + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" + "For more debugging help, see " + "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + + maybe_extra_debug + # TODO: Help text about how to use our runtime tests to fix this + # problem + ) + + dtrace_structured( + "guard_on_data_dependent_error", + metadata_fn=lambda: { + "expr": repr(expr), + "unhinted_expr": repr(unhinted_expr), + "expr_id": self._expr_sym_node_id, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + return GuardOnDataDependentSymNode(expr, msg) + + def _update_var_to_range( + self, + symbol: sympy.Symbol, + vr: ValueRanges, + vr_sloc: Optional[ValueRangesSLoc] = None, + *, + is_constraint: bool = False, + ) -> None: + lower, upper = vr.lower, vr.upper + + # If we have a size-like unbacked SymInt, refuse to refine the range to be + # less than two. This is because when we intersect this range + # with [2, inf] for size oblivious tests, the range would be + # unsatisfiable. In other words, once you have a size-like + # unbacked SymInt, we can never learn that it is exactly zero or one, + # because we would now give inconsistent results for all size + # oblivous tests! + if upper < 2 and symbol in self.size_like: + vr = ValueRanges(lower, 2) + + # Updates the range and the guards corresponding to each bound of the symbol. + if symbol not in self.var_to_range: + self.log.debug("_update_var_to_range %s = %s (new)", symbol, vr) + self.var_to_range[symbol] = vr + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + self.var_to_range_sloc[symbol] = vr_sloc + else: + old = self.var_to_range[symbol] + new = old & vr + if new != old: + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + if new.lower != old.lower: + self.var_to_range_sloc[symbol].lower = vr_sloc.lower + if new.upper != old.upper: + self.var_to_range_sloc[symbol].upper = vr_sloc.upper + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) + + if (v := self.var_to_val.get(symbol)) is not None: + r = self.var_to_range[symbol] + if v not in r: + # For constraint failure, delay this for later + # TODO: Rework all of this, the constraint logic is very + # duplicative with regular reasoning + if not is_constraint: + assert v in r, f"{v} not in {r}" + + def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. + """ + + if tgt == self.replacements.get(a, None): + return + + if a in tgt.free_symbols: + return + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if ( + self.prefer_deferred_runtime_asserts_over_guards + and not _is_supported_equivalence(tgt) + ): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug( + "set_replacement: solve for %s in %s == %s gives %s", + b, + a, + tgt, + r, + ) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges( + CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper) + ) + self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset(src_bound), ( + f"{tgt_bound=} not a subset of {src_bound=}" + ) + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. + # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s]", + a, + tgt, + msg, + tgt_bound, + src_bound, + ) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug( + "skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", + a, + tgt, + msg, + tgt_bound_so, + src_bound_so, + ) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + "user_stack": ( + structured.from_traceback(user_tb) if user_tb else None + ), + }, + ) + + for source in self.var_to_sources.get(a, []): + if user_tb: + self.specialization_stacks[source] = user_tb + + if config.print_specializations: + self.log.warning( + "Specializing %s to %s", self.var_to_sources[a][0].name, tgt + ) + self.log.debug("SPECIALIZATION", stack_info=True) + log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + # NB: the replacement may get refined, but the user will find the + # FIRST one most useful (TODO: Maybe we could consider tracking all of + # them) + if a not in self.replacements_slocs: + self.replacements_slocs[a] = self._get_sloc() + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. + self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) + + def _add_divisible(self, expr: sympy.Expr) -> None: + self.divisible.add(expr) + self._update_version_counter() + + @_lru_cache + @record_shapeenv_event() + def _find(self, a: sympy.Symbol) -> sympy.Expr: + """ + Implements a DSU-like algorithm to find the variable that represents a + Also handles transitive non-identity replacements. + + a: b + c + c: d + """ + if a not in self.replacements: + return a + res = self.replacements[a] + cur_replace = {s: self._find(s) for s in res.free_symbols} + replaced, changed = self.replacements[a]._xreplace(cur_replace) + if changed: + self._set_replacement(a, replaced, "find") + return self.replacements[a] + + @lru_cache(256) + def _maybe_guard_rel(self, expr: sympy.Expr) -> None: + """ + The relational guard is guarded to be true. Use this information to + simplify shapes (i.e. a == b or a % 5 == 0) + """ + if isinstance(expr, sympy.And): + for arg in expr.args: + self._maybe_guard_rel(arg) + return + elif not isinstance(expr, sympy.Rel): + return + + # A good example of what goes wrong if you don't do this is + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + if isinstance(expr, sympy.Ne): + return + + free = list(expr.free_symbols) + + assert len(free) > 0, ( + f"The expression should not be static by this point: {expr}" + ) + # In case of really gnarly expression, we don't blow up + if len(free) > 5: + return + + # Prioritize unbacked symints for solving by ordering them last. + # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). + # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) + # Prefer to simplify out symbols with ephemeral sources. + def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]: + has_only_ephemeral_sources = x in self.var_to_sources and all( + s.is_ephemeral() for s in self.var_to_sources[x] + ) + # NB: size_hint is int, not sympy.Expr, do not use int_oo here + hint_size = self.size_hint(x, allow_none=True) + if hint_size is None: + size = sys.maxsize + elif symbol_is_type(x, SymT.SIZE): + assert isinstance(hint_size, sympy.Expr) + size = int(hint_size) + else: + size = sys.maxsize + name = x.name + # 1 puts ephemeral sourced symbols first when sorting in reverse + return (1 if has_only_ephemeral_sources else 0, size, name) + + free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] + lhs = expr.lhs + rhs = expr.rhs + + self._refine_ranges(expr) + + # The rest of this stuff is for equality only + if not isinstance(expr, sympy.Eq): + return + + if not expr.has(Mod): + try: + floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) + if len(floor_div_atoms) > 0 and any( + a.divisor != 1 for a in floor_div_atoms + ): + raise NotImplementedError + + # Never replace unbacked symbols with other unbacked symbols that are + # not function arguments. (ex:mark_unbacked symbols are fine to replace + # other unbacked, but not those coming from .item() calls). + + # This is error prone because you can cause references to + # unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. It's pretty inconvenient to setup control + # dependencies for substitutions, so ban it entirely. + def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool: + if isinstance(lhs, sympy.Symbol): + if free_unbacked_symbols( + lhs + ) and not _free_non_source_unbacked_symbols( + rhs, self.unbacked_inputs + ): + return True + if symbol_is_type(lhs, SymT.FLOAT): + return True + # TODO: Maybe trivial solutions for int should also be + # done? + return False + + # short-circuit when no solving is needed + if trivial_solve(lhs, rhs): + self._set_replacement(lhs, self._find(rhs), "trivial_lhs") + elif trivial_solve(rhs, lhs): + self._set_replacement(rhs, self._find(lhs), "trivial_rhs") + else: + r = try_solve(expr, free[0], floordiv_inequality=False) + if r is not None and all( + t.is_integer for t in sympy.preorder_traversal(r[1]) + ): + new_var = self._find(r[1]) + ok = len(free_unbacked_symbols(new_var)) == 0 + if ok: + self._set_replacement(free[0], new_var, "solve") + + except NotImplementedError: + pass + else: + # expression has mod. + mod_expr = next(iter(expr.atoms(Mod))) + try: + r = try_solve(expr, mod_expr, floordiv_inequality=False) + if r is not None and r[1] == 0: + self._add_divisible(mod_expr) + except NotImplementedError: + pass + return + + # See: Note - On 0/1 specialization + def _default_value_range( + self, do_not_specialize_zero_one: bool = False + ) -> ValueRanges: + lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2 + return ValueRanges(lower, int_oo) + + def _default_unspecified_value_range(self) -> ValueRanges: + return ValueRanges.unknown_int() + + @_lru_cache + def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr: + floor_divs = tuple(expr.atoms(FloorDiv)) + # we expect floor_divs to be exact, + # and thus add the guards for the exact floordivs, + # even if tracing doesn't require them otherwise + for fd in reversed(floor_divs): + base, divisor = fd.args + mod_expr = Mod(base, divisor) + eq_expr = sympy.Eq(mod_expr, 0) + # add necessary mod guards + self.evaluate_expr(eq_expr) + return self.simplify(expr) + + # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen + # and if so issue a warning + def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None: + if self.frozen: + self.counter["ignored_backward_guard"] += 1 + signpost_event( + "dynamic", + "evaluate_expr_frozen", + { + **self.co_fields, + "ignored_guard": f"{expr} == {concrete_val}", + # no version = original state (this signpost is expected) + # version 2 = dynamic backwards is eagerly compiled + "version": 2, + }, + ) + log.info( + "Ignored guard %s == %s, this could result in accuracy problems", + expr, + concrete_val, + # only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic") + stack_info=log.getEffectiveLevel() < logging.WARNING, + ) + + def _get_user_frame(self) -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if frame.f_code.co_filename not in uninteresting_files(): + return frame + frame = frame.f_back + return frame + + def _get_stack_summary( + self, is_debug: bool = False, framework_loc: Optional[str] = None + ) -> tuple[SLoc, str]: + floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc + if floc is None: + frame = self._get_user_frame() + try: + if frame is not None: + floc = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + finally: + del frame + + # NB: this stack is truncated, but it's fine because the main + # stack_info will give you the rest of the info you need + maybe_user_loc = None + user_tb = TracingContext.extract_stack() + if user_tb: + idx = len(user_tb) - 1 + while idx > 0 and user_tb[idx].filename in uninteresting_files(): + idx -= 1 + maybe_user_loc = format_frame(user_tb[idx], line=True) + + maybe_extra_debug = "" + if is_debug and user_tb: + maybe_extra_debug = ( + "\nUser Stack (most recent call last):\n" + + " (snipped, see stack below for prefix)\n" + + "".join(traceback.format_list(user_tb)) + ) + if is_debug and config.extended_debug_cpp: + cpp_stack = CapturedTraceback.extract(cpp=True) + maybe_extra_debug += "\nC++ stack trace:\n" + "".join(cpp_stack.format()) + elif is_debug: + maybe_extra_debug += ( + "\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" + ) + + return SLoc(floc, maybe_user_loc), maybe_extra_debug + + # Pass in framework_loc to override the framework location info + def _get_sloc(self, framework_loc: Optional[str] = None) -> SLoc: + sloc, _ = self._get_stack_summary(framework_loc=framework_loc) + return sloc + + def _generate_unique_id(self, source_name: str) -> int: + attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100 + while attempt in self.unique_ids: + attempt += 1 + self.unique_ids.add(attempt) + return attempt + + def _find_frame_locals(self) -> _FrameLocalResult: + """ + Given the current user code frame, finds the relevant lines of code, + values of symbolic locals, and free symbols involved. + """ + frame_locals: dict[str, Any] = {} + frame_symbols: dict[str, str] = {} + + if ( + frame := _find_user_code_frame() + ) is None or frame.f_code.co_filename == "": + return _FrameLocalResult() + + # find bytecode instructions relevant to the frame + instructions = list(dis.Bytecode(frame.f_code)) + co_lines, offset = inspect.getsourcelines(frame.f_code) + start, end, cur = None, None, None + # pyrefly: ignore [bad-assignment] + for i, instr in enumerate(instructions): + if instr.starts_line is not None: + cur = instr.starts_line + if cur != frame.f_lineno: + continue + if start is None: + start = end = i + else: + end = i + + if start is None or end is None: # no instructions found + return _FrameLocalResult() + + # track involved locals and free symbols + def go(x: Any) -> Optional[str]: + if isinstance(x, torch.Tensor): + for y in x.size(): + go(y) + for y in x.stride(): + go(y) + go(x.storage_offset()) + return ( + f"Tensor(shape: {x.size()}, " + f"stride: {x.stride()}, " + f"storage_offset: {x.storage_offset()})" + ) + elif isinstance(x, (SymBool, SymInt, SymFloat)): + for s in x.node.expr.free_symbols: + if str(s) in frame_symbols: # type: ignore[operator] + continue + if s in self.var_to_sources: + frame_symbols[str(s)] = self.var_to_sources[s][0].name # type: ignore[assignment] + return str(x) + return None + + # go through instructions, seeing linenos & involved locals + last_lineno = frame.f_lineno + for instr in instructions[start : end + 1]: + if (lineno := instr.starts_line) is not None: + last_lineno = max(last_lineno, lineno) + if isinstance(instr.argval, str) and instr.argval in frame.f_locals: + flat_locals = pytree.tree_flatten(frame.f_locals[instr.argval])[0] + frame_locals[instr.argval] = [ + go(flat_local) for flat_local in flat_locals + ] + + # store LOC + locs = co_lines[frame.f_lineno - offset : last_lineno + 1 - offset] + if not locs: + return _FrameLocalResult() + + indent = len(locs[0]) - len(locs[0].lstrip()) + frame_loc = "".join([loc[indent:] for loc in locs]).strip() # type: ignore[assignment] + return _FrameLocalResult( + loc=frame_loc, locals=frame_locals, symbols=frame_symbols + ) + + def _log_guard(self, prefix: str, g: SympyBoolean, forcing_spec: bool) -> None: + dtrace_structured( + "guard_added", + metadata_fn=lambda: { + "expr": str(g), + "prefix": prefix, + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in g.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + trace_structured( + "guard_added_fast", + metadata_fn=lambda: { + "expr": str(g), + "user_stack": structured.from_traceback(TracingContext.extract_stack()), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + if self.log.isEnabledFor(logging.INFO): + str_g = str(g) + is_debug = ( + config.extended_debug_guard_added is not None + and str_g == config.extended_debug_guard_added + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + maybe_more_info = "" + if not is_debug: + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' + ) + self.log.info( + "%s %s [guard added] %s%s%s", + prefix if not forcing_spec else f"{prefix} (forcing_spec)", + str_g, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + + # A local variable to evaluate_expr stored in the class to avoid + # using it for the lru_cache that is on top of it since it does + # not effect the results. When needed its read directly. + _expr_sym_node_id: Optional[int] = None + + def evaluate_sym_node( + self, + sym_node: SymNode, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + """ + Given a a SymNode, evaluates sym_node.expr, adding guards if necessary. + """ + + self._expr_sym_node_id = id(sym_node) + return self.evaluate_expr( + sym_node.expr, + sym_node.hint, + sym_node.fx_node, + size_oblivious, + fallback_value=fallback_value, + ) + + def _is_python_assert(self) -> bool: + # Check if this boolean is used in an assertion, bytecode pattern for + # assertions is pretty stable for Python 3.7--3.13, ported with minimal + # changes from torch/fx/proxy.py + # Bytecode pattern for `assert` statements: + # TO_BOOL / COMPARE_OP # Only for Python >= 3.13 + # POP_JUMP_IF_TRUE + # LOAD_ASSERTION_ERROR + # RAISE_VARARGS + frame = self._get_user_frame() + assert frame is not None + + insts = list(dis.get_instructions(frame.f_code)) + if sys.version_info >= (3, 11): + # For Python >= 3.11, instructions can be 2-4 bytes long. + from bisect import bisect_left + + cur = bisect_left(insts, frame.f_lasti, key=lambda x: x.offset) + else: + # For Python <= 3.10, instructions are always 2 bytes. + cur = frame.f_lasti // 2 + + if sys.version_info >= (3, 13): + if insts[cur].opname in ("TO_BOOL", "COMPARE_OP"): + # Peek 1 instruction further. + cur += 1 + + assert_insts = torch._dynamo.symbolic_convert.get_assert_bytecode_sequence( + False + ) + + cur_insts = insts[cur + 1 : cur + 1 + len(assert_insts)] + cur_insts = [inst.opname for inst in cur_insts] + return cur_insts == assert_insts + + def _log_real_tensor_propagation( + self, orig_expr: sympy.Basic, unsound_result: sympy.Basic + ) -> None: + log.warning( + "propagate_real_tensors evaluate_expr(%s) -> %s", + orig_expr, + unsound_result, + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + dtrace_structured( + "propagate_real_tensors_provenance", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in orig_expr.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + + def evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + """ + Given an expression, evaluates it, adding guards if necessary + When fallback_value is not None the function return fallback_value instead of failing with data dependent error. + """ + + # Add extra state that evaluate_expr() depends on. + suppress_guards_tls = ShapeEnv._suppress_guards_tls() + return self._inner_evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + forcing_spec, + suppress_guards_tls, + fallback_value, + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr") + def _inner_evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]], + fx_node: Optional[torch.fx.Node], + size_oblivious: bool, + forcing_spec: bool, + _suppress_guards_tls: bool, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + try: + return self._evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + fallback_value, + forcing_spec=forcing_spec, + ) + except Exception as e: + if isinstance(e, GuardOnDataDependentSymNode): + pass + else: + self.log.warning( + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, + hint, + size_oblivious, + forcing_spec, + ) + raise + + def _log_suppressed_dde(self, a: SymBool, assumed_value: bool) -> None: + sloc, extra = self._get_stack_summary(True) + log.info( + "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s", + a, + assumed_value, + sloc, + extra, + ) + + def _evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[bool, int, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + # TODO: split conjunctions and evaluate them separately + if isinstance( + orig_expr, + (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse), + ): + return orig_expr + + # Don't track this one. (Because this cache is inside this function the + # cache only lasts for the invocation of this function call) + @functools.cache + def compute_concrete_val() -> sympy.Basic: + if hint is None: + # This is only ever called for expressions WITHOUT unbacked + # symbols + r = self.size_hint(orig_expr) + assert r is not None + return r + else: + return sympy.sympify(hint) + + concrete_val: Optional[sympy.Basic] + + # Check if: + # 1. 'translation_validation' is set + # 2. the corresponding 'fx_node' is not 'None' + # 3. the guard should not be suppressed + # 4. the guard doesn't contain backed symfloat symbols + # since z3 can't handle floats + # 5. fallback_value is none. + # If all of the above check, we create an FX node representing the + # actual expression to be guarded. + node = None + fresh = False + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + and not size_oblivious + and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols) + and fallback_value is None + ): + # TODO: does this even worked with unbacked :think: + concrete_val = compute_concrete_val() + if concrete_val is sympy.true: + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + elif concrete_val is sympy.false: + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) + else: + eql, _ = self._create_fx_call_function( + operator.eq, (fx_node, concrete_val) + ) + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) + + assert node is not None + # If this is a fresh node, we have to remember the event index that + # corresponds to this assertion node. + # Reason: so that, given an assertion node, we can replay the ShapeEnv + # events until the point where this assertion node was freshly created. + if fresh: + self._add_fx_node_metadata(node) + + # After creating the FX node corresponding to orig_expr, we must make sure that + # no error will be raised until the end of this function. + # + # Reason: the translation validation may become invalid otherwise. + # + # If an error is raised before the end of this function, we remove the FX node + # inserted, and re-raise the error. + guard = None + + try: + if orig_expr.is_number: + self.log.debug("eval %s [trivial]", orig_expr) + if hint is not None: + if isinstance(hint, bool): + assert orig_expr == hint, f"{orig_expr} != {hint}" + else: + assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}" + return orig_expr + + expr = orig_expr + + static_expr = self._maybe_evaluate_static( + expr, size_oblivious=size_oblivious + ) + if static_expr is not None: + self.log.debug( + "eval %s == %s [statically known]", + ( + f"size_oblivious({orig_expr})" + if size_oblivious + else size_oblivious + ), + static_expr, + ) + if ( + not size_oblivious + and config.backed_size_oblivious + and hint is not None + ): + # TODO: maybe reconcile this with use of counterfactual hints + # in unbacked case + assert static_expr == hint, f"{static_expr} != {hint}" + return static_expr + + transmute_into_runtime_assert = False + + concrete_val = None + if not (expr.free_symbols <= self.var_to_val.keys()): + # TODO: dedupe this with _maybe_evaluate_static + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if not (new_expr.free_symbols <= self.var_to_val.keys()): + ok = False + + # fallback_value is set when guard_or_true or guard_or_false are used. + if not ok and fallback_value is not None: + self._log_suppressed_dde(orig_expr, fallback_value) + return fallback_value + + # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type. + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + if ( + self.oblivious_var_to_val + and not ( + correct_hint := orig_expr.xreplace( + self.oblivious_var_to_val + ) + ).free_symbols + and not ( + counterfactual_hint := orig_expr.xreplace( + { + k: max(2, v) + for k, v in self.oblivious_var_to_val.items() + } + ) + ).free_symbols + and correct_hint == counterfactual_hint + ): + # TODO: better logging + log.info( + "oblivious_size %s -> %s (passed counterfactual)", + orig_expr, + correct_hint, + ) + + concrete_val = correct_hint + # NB: do NOT transmute into runtime assert + ok = True + + # unbacked_var_to_val is not None iff propagate_real_tensors is on. + # if propagate_real_tensors is on, we check the example values to generate (unsound_result) + # and if they pass we add a runtime assertions and continue. + if ( + not ok + and self.unbacked_var_to_val + and not ( + unsound_result := orig_expr.xreplace( + self.unbacked_var_to_val + ).xreplace(self.var_to_val) + ).free_symbols + ): + self._log_real_tensor_propagation(orig_expr, unsound_result) + transmute_into_runtime_assert = True + + concrete_val = unsound_result + ok = True + + # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion + # instead of failing. + if not ok and self.trace_asserts and self._is_python_assert(): + concrete_val = sympy.true + transmute_into_runtime_assert = True + ok = True + + if not ok: + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + expr_sym_node_id=self._expr_sym_node_id, + ) + else: + expr = new_expr + + if concrete_val is None: + concrete_val = compute_concrete_val() + self._check_frozen(expr, concrete_val) + + if ( + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY + and isinstance(hint, bool) + and isinstance(expr, (sympy.Eq, sympy.Ne)) + ): + expr = sympy.Not(expr) + + # Turn this into a boolean expression, no longer need to consult + # concrete_val + if concrete_val is sympy.true: + g = cast(SympyBoolean, expr) + elif concrete_val is sympy.false: + g = sympy.Not(expr) + else: + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] + + if transmute_into_runtime_assert: + self.guard_or_defer_runtime_assert( + g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" + ) + return concrete_val + + if not self._suppress_guards_tls(): + self._log_guard("eval", g, forcing_spec=forcing_spec) + + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) + + if ( + torch.compiler.is_exporting() + and self.prefer_deferred_runtime_asserts_over_guards + ): + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + else: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. + guard = ShapeGuard( + g, self._get_sloc(), size_oblivious=size_oblivious + ) + self.guards.append(guard) + self.axioms.update(dict(self.get_implications(self.simplify(g)))) + else: + self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) + + except Exception: + if fresh: + self._remove_fx_node(node) + raise + + if not self._suppress_guards_tls(): + if guard is not None: # we might have deferred this to runtime assert + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec + and config.symbol_guard_limit_before_specialize is not None + and self.symbol_guard_counter[s] + > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s, + ) + self.evaluate_expr(s, forcing_spec=True) + + return concrete_val + + def cleanup(self) -> None: + """ + Break reference cycles. + + This destroys the stacks. If you really want to keep them, we + just need some way to break references on code objects. + """ + for s in self.var_to_stack.values(): + s.cleanup() + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + ra.stack.cleanup() + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True) + def guard_or_defer_runtime_assert( + self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None + ) -> bool: + """ + Adds a guard that orig_expr is True if we can or fall back to adding an assert + that is checked at runtime. + + Args: + orig_expr (sympy.Expr): Boolean expression to assert is true + msg (str): Message to display on assertion failure + fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding + to the expression, if applicable + """ + expr = orig_expr + + # TODO: split conjunctions and evaluate them separately + + static_expr = self._maybe_evaluate_static(expr) + if static_expr is not None: + self.log.debug( + "runtime_assert %s == %s [statically known]", orig_expr, static_expr + ) + # TODO: assert bool(static_expr) + return bool(static_expr) + + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if ( + not self.prefer_deferred_runtime_asserts_over_guards + and new_expr.free_symbols <= self.var_to_val.keys() + ): + # Do a normal guard + return self.evaluate_expr(new_expr, fx_node=fx_node) + # NB: Don't use new_expr as expr; it could contain gunk like shape0 + # which we don't want to guard on + + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + ): + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + assert node is not None + if fresh: + self._add_fx_node_metadata(node) + + if not self._suppress_guards_tls(): + self._log_guard("runtime_assert", orig_expr, forcing_spec=False) + # If you're here because of this assert, read Note [Backwards runtime asserts] + # in torch/_inductor/graph.py + if self.runtime_asserts_frozen: + log.debug("runtime_asserts_frozen but then got %s", expr) + self._check_frozen(expr, sympy.true) + # eliminate symbols on equality tests / refine ranges + self._maybe_guard_rel(expr) + + # canonicalise to remove equations that are trivially equal + orig_expr = expr + expr = canonicalize_bool_expr(expr) + stack = CapturedTraceback.extract(skip=1) + ra = RuntimeAssert(expr, msg, stack) + + # TODO: Do this in a way that is less janky than int(s.name[1:]) + cands = sorted( + (s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), + key=lambda s: int(s.name[1:]), + ) + # Is None when prefer_deferred_runtime_asserts_over_guards=True + # and the guard in question has no unbacked SymInts in front + ix = cands[-1] if cands else None + self.deferred_runtime_asserts.setdefault(ix, []).append(ra) + self.axioms.update(dict(self.get_implications(self.simplify(expr)))) + self.num_deferred_runtime_asserts += 1 + self._update_version_counter() + else: + self._log_guard( + "runtime_assert [guard suppressed]", orig_expr, forcing_spec=False + ) + + return True + + # Refines the ranges of the variables present in 'guard'. + # + # This function tries to refine the range of the variables inside + # 'guard' by reasoning about it. Specifically, when 'guard' is a + # 'sympy.Relational' operation. + # + # It does mainly 3 things: + # 1. Tries to isolate a variable in the left-hand side + # 2. Compute the value range of the right-hand side + # 3. Update the value range of the variable, if better + def _refine_ranges(self, expr: SympyBoolean) -> None: + expr = self.simplify(expr) + + for symbol in expr.free_symbols: + assert isinstance(symbol, sympy.Symbol) + + if isinstance(self.var_to_val.get(symbol, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + + r = try_solve(expr, symbol) + + if r is None or not (symbol.is_integer and r[1].is_integer): + # Range refinement only supports integer symbols for now. + # There are lots of SymPy bugs when it comes to comparing + # reals and integers, so we skip that for now. + continue + + r_expr, rhs = r + vr = self.var_to_range[symbol] + lower, upper = vr.lower, vr.upper + + rhs_vr = bound_sympy(rhs, self.var_to_range) + + # Let's suppose that we have a preexisting range for x [0, 100]. + # Now, we issue a guard x > y, where the range for y is [50, 150]. + # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, + # refining x to [51, 100], since x must be greater than y, but the lowest + # y could be is 50. + # + # sympy.Eq may update both lower and upper bounds. + # sympy.G{t,e} may update the lower bound, only. + # sympy.L{t,e} may update the upper bound, only. + if lower <= rhs_vr.lower and isinstance( + r_expr, (sympy.Eq, sympy.Ge, sympy.Gt) + ): + # Strictly greater relations allow us to refine a bit more, since + # x < y implies that the lower bound for x is: y + 1. + lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) + if upper >= rhs_vr.upper and isinstance( + r_expr, (sympy.Eq, sympy.Le, sympy.Lt) + ): + upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) + + # Do nothing if the new value range is no better than what we already have. + if vr == ValueRanges(lower, upper): + continue + + # Updates the range and the guards corresponding to each bound of the symbol. + self._update_var_to_range(symbol, ValueRanges(lower, upper)) + # If the range is refined to singleton, set replacement + if self.var_to_range[symbol].is_singleton(): + self._set_replacement( + symbol, + self.var_to_range[symbol].lower, + "range_refined_to_singleton", + ) + + # Clears the cache, since this update can change the result. + self._maybe_evaluate_static.cache_clear() + + @lru_cache(maxsize=None) + @record_shapeenv_event() + def constrain_symbol_range( + self, s: sympy.Symbol, compiler_min: int, compiler_max: int + ) -> None: + upd_vr = ValueRanges(compiler_min, compiler_max) + old_vr = self.var_to_range.get(s, ValueRanges.unknown()) + self._update_var_to_range(s, upd_vr) + if (new_vr := self.var_to_range[s]) != old_vr: + log.info( + "constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper + ) + + +def _is_int(expr: object) -> bool: + return isinstance(expr, SymInt) and expr.node.expr.is_number + + +# WARNING: This is legacy, DO NOT USE +def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool: + return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices + + +class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Result: + """ + Run an FX node, propagating unbacked Symbol bindings to the new fake tensor + """ + from torch._guards import detect_fake_mode + + result = super().run_node(n) + fake_mode = detect_fake_mode() + assert fake_mode is not None + rebind_unbacked(fake_mode.shape_env, n, result) + return result + + +def _find_user_code_frame() -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if not frame.f_code.co_filename.startswith( + os.path.dirname(inspect.getfile(torch)) + os.path.sep + ): + break + frame = frame.f_back + return frame + + +def _blame_user_code(e: Exception, frame: types.FrameType) -> None: + frame_summary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + msg = e.args[0] + msg += "\n\nThe following call raised this error:\n" + "".join( + traceback.StackSummary.from_list([frame_summary]).format() + ) + e.args = (msg,) + + +class _PythonMsgPrinter(PythonPrinter): + """ + Util printer that replaces sympy symbols with their source-level names + and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline + (i.e., as ==, !=, >, <). + """ + + def __init__(self, src_map: dict[str, list[str]]) -> None: + super().__init__() + self.src_map = src_map + + def _print_Symbol(self, sym: sympy.Symbol) -> str: + return self.src_map[sym.name][0] + + +def _suggest_torch_checks( + e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]] +) -> None: + """ + Enhances a GuardOnDataDependentSymNode error with suggested fixes using torch._check. + + This function analyzes the condition that caused the data-dependent error and generates + user-friendly suggestions for fixing it by adding appropriate torch._check calls. + It handles special cases like non-negative checks with specific recommendations. + + Args: + e: The GuardOnDataDependentSymNode error to enhance with suggestions + src_map: A mapping from symbol names to their corresponding source-level variable names + + Returns: + None. Modifies the error message in-place by updating e.args[0]. + """ + # extract the unresolved condition on unbacked symints in the error + cond = e.cond + diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) + if diff: + log.warning("Unable to find user code corresponding to {%s}", diff) + return + printer = _PythonMsgPrinter(src_map) + msg = e.args[0] + msg += "\nTo fix the error, insert one of the following checks before this call:" + + not_cond_str = printer.doprint(sympy.Not(cond)) + + # suggested fixes to resolve `cond` are to tell the compiler to assume + # either `cond` or its negation (the user will need to select which) + suggested_fixes = [ + f"torch._check({printer.doprint(cond)})", + f"torch._check({not_cond_str})", + ] + + for i, fix in enumerate(suggested_fixes): + msg += f"\n {i + 1}. {fix}" + src_mapped = ", ".join( + f"`{s}` with {' or '.join(src_map[s])}" + for s in sorted(s.name for s in cond.free_symbols) + ) + msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)" + e.args = (msg,) + + +def _suggest_fixes_for_data_dependent_error_non_strict( + e: GuardOnDataDependentSymNode, +) -> None: + """ + Given a raised data-dependent error, add the following to the error message: + 1. the closest user code location that raised the error; + 2. suggested fixes for the error in terms of live variables at that location. + """ + + # walk the stack up from the data-dependent error until a non-torch frame is found + frame = _find_user_code_frame() + if frame is not None: + # add frame info to error message + _blame_user_code(e, frame) + + # map symbol names reachable via frame locals to their source-level names + src_map = defaultdict(list) + for var, val in frame.f_locals.items(): + try: + tree_leaves_with_path = pytree.tree_leaves_with_path(val) + except ValueError: + log.warning( + "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}", + type(val), + var, + ) + continue + # figure out how to access any symbol inside `val` through `var` + for path, leaf in tree_leaves_with_path: + name = var + pytree.keystr(path) + if isinstance(leaf, torch.SymInt): + src_map[str(leaf.node.expr)].append(name) + elif isinstance(leaf, torch.Tensor): + for i, dim in enumerate(leaf.shape): + if isinstance(dim, torch.SymInt): + src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") + + # add suggested torch.check()s based on `src_map` to the error message + # replacing unbacked symints in the unresolved condition in the error + if isinstance(e.cond, sympy.logic.boolalg.Boolean): + _suggest_torch_checks(e, src_map) + + +@contextmanager +def _remove_effect_token_unbacked_bindings( + node: torch.fx.Node, +) -> Generator[None, None, None]: + """ + Temporarily modifies unbacked_bindings in a node's metadata by removing the first element + of each path, which corresponds to an effect token. + + This is used when processing nodes that have effect tokens as the first element in their + unbacked_bindings paths. The context manager ensures that the original bindings are + restored after the operation is complete. + + Args: + node: The FX node whose unbacked_bindings will be temporarily modified + + Yields: + None + """ + old_bindings = node.meta.get("unbacked_bindings", {}) + + # Remove the extra layer for effect token + new_bindings = {k: path[1:] if path else path for k, path in old_bindings.items()} + + node.meta["unbacked_bindings"] = new_bindings + + try: + yield + finally: + node.meta["unbacked_bindings"] = old_bindings + + +# This helper function is used in passes that insert runtime assertions in the graph. +# When accessing expressions representing input placeholders, we do not apply replacements +# since those inputs should be seen by assertions that use them to be inserted. The only replacement +# that we apply is unbacked renaming. +def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr: + shape_env = sym_node.shape_env + result = sym_node._expr + if result in shape_env.unbacked_renamings: + return shape_env.unbacked_renamings[result] + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7db0e29d1d4f75c770562c65013c03817643f6b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py @@ -0,0 +1,4 @@ +# mypy: disable-error-code=attr-defined +from .core import reify, unify # noqa: F403 +from .more import unifiable # noqa: F403 +from .variable import isvar, Var, var, variables, vars # noqa: F403 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8071c847ae5da144d7ab57b5d24e7968b5daf6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterator # type: ignore[import] +from functools import partial + +from .dispatch import dispatch +from .unification_tools import assoc # type: ignore[import] +from .utils import transitive_get as walk +from .variable import isvar + + +__all__ = ["reify", "unify"] + +############### +# Reification # +############### + + +@dispatch(Iterator, dict) +def _reify(t, s): + return map(partial(reify, s=s), t) + # return (reify(arg, s) for arg in t) + + +_reify + + +@dispatch(tuple, dict) # type: ignore[no-redef] +def _reify(t, s): + return tuple(reify(iter(t), s)) + + +_reify + + +@dispatch(list, dict) # type: ignore[no-redef] +def _reify(t, s): + return list(reify(iter(t), s)) + + +_reify + + +@dispatch(dict, dict) # type: ignore[no-redef] +def _reify(d, s): + return {k: reify(v, s) for k, v in d.items()} + + +_reify + + +@dispatch(object, dict) # type: ignore[no-redef] +def _reify(o, s): + return o # catch all, just return the object + + +def reify(e, s): + """Replace variables of expression with substitution + >>> # xdoctest: +SKIP + >>> x, y = var(), var() + >>> e = (1, x, (3, y)) + >>> s = {x: 2, y: 4} + >>> reify(e, s) + (1, 2, (3, 4)) + >>> e = {1: x, 3: (y, 5)} + >>> reify(e, s) + {1: 2, 3: (4, 5)} + """ + if isvar(e): + return reify(s[e], s) if e in s else e + return _reify(e, s) + + +############### +# Unification # +############### + +seq = tuple, list, Iterator + + +@dispatch(seq, seq, dict) # type: ignore[arg-type] +def _unify(u, v, s): + if len(u) != len(v): + return False + for uu, vv in zip(u, v): # avoiding recursion + s = unify(uu, vv, s) + if s is False: + return False + return s + + +# +# @dispatch((set, frozenset), (set, frozenset), dict) +# def _unify(u, v, s): +# i = u & v +# u = u - i +# v = v - i +# return _unify(sorted(u), sorted(v), s) +# +# +# @dispatch(dict, dict, dict) +# def _unify(u, v, s): +# if len(u) != len(v): +# return False +# for key, uval in iteritems(u): +# if key not in v: +# return False +# s = unify(uval, v[key], s) +# if s is False: +# return False +# return s +# +# +# @dispatch(object, object, dict) +# def _unify(u, v, s): +# return False # catch all + + +@dispatch(object, object, dict) +def unify(u, v, s): # no check at the moment + """Find substitution so that u == v while satisfying s + >>> x = var("x") + >>> unify((1, x), (1, 2), {}) + {~x: 2} + """ + u = walk(u, s) + v = walk(v, s) + if u == v: + return s + if isvar(u): + return assoc(s, u, v) + if isvar(v): + return assoc(s, v, u) + return _unify(u, v, s) + + +unify + + +@dispatch(object, object) # type: ignore[no-redef] +def unify(u, v): + return unify(u, v, {}) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..72b950c5b36d67f34cca322ffbbf6851b151de36 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py @@ -0,0 +1,8 @@ +from functools import partial + +from .multipledispatch import dispatch as _dispatch # type: ignore[import] + + +namespace = {} # type: ignore[var-annotated] + +dispatch = partial(_dispatch, namespace=namespace) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py new file mode 100644 index 0000000000000000000000000000000000000000..01861a086f64b6121aa9e174d16176533cd0e1a5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from .core import reify, unify # type: ignore[attr-defined] +from .unification_tools import first, groupby # type: ignore[import] +from .utils import _toposort, freeze +from .variable import isvar + + +class Dispatcher: + def __init__(self, name): + self.name = name + self.funcs = {} + self.ordering = [] + + def add(self, signature, func): + self.funcs[freeze(signature)] = func + self.ordering = ordering(self.funcs) + + def __call__(self, *args, **kwargs): + func, _ = self.resolve(args) + return func(*args, **kwargs) + + def resolve(self, args): + n = len(args) + for signature in self.ordering: + if len(signature) != n: + continue + s = unify(freeze(args), signature) + if s is not False: + result = self.funcs[signature] + return result, s + raise NotImplementedError( + "No match found. \nKnown matches: " + + str(self.ordering) + + "\nInput: " + + str(args) + ) + + def register(self, *signature): + def _(func): + self.add(signature, func) + return self + + return _ + + +class VarDispatcher(Dispatcher): + """A dispatcher that calls functions with variable names + >>> # xdoctest: +SKIP + >>> d = VarDispatcher("d") + >>> x = var("x") + >>> @d.register("inc", x) + ... def f(x): + ... return x + 1 + >>> @d.register("double", x) + ... def f(x): + ... return x * 2 + >>> d("inc", 10) + 11 + >>> d("double", 10) + 20 + """ + + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + d = {k.token: v for k, v in s.items()} + return func(**d) + + +global_namespace = {} # type: ignore[var-annotated] + + +def match(*signature, **kwargs): + namespace = kwargs.get("namespace", global_namespace) + dispatcher = kwargs.get("Dispatcher", Dispatcher) + + def _(func): + name = func.__name__ + + if name not in namespace: + namespace[name] = dispatcher(name) + d = namespace[name] + + d.add(signature, func) + + return d + + return _ + + +def supercedes(a, b): + """``a`` is a more specific match than ``b``""" + if isvar(b) and not isvar(a): + return True + s = unify(a, b) + if s is False: + return False + s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} + if reify(a, s) == a: + return True + if reify(b, s) == b: + return False + + +# Taken from multipledispatch +def edge(a, b, tie_breaker=hash): + """A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +# Taken from multipledispatch +def ordering(signatures): + """A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(first, edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] + return _toposort(edges) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py new file mode 100644 index 0000000000000000000000000000000000000000..42074a46a4202cface9799af4b81743c292e766d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +from .core import ( # type: ignore[attr-defined] + _reify as core_reify, + _unify as core_unify, + reify, + unify, +) +from .dispatch import dispatch + + +__all__ = ["unifiable", "reify_object", "unify_object"] + + +def unifiable(cls): + """Register standard unify and reify operations on class + This uses the type and __dict__ or __slots__ attributes to define the + nature of the term + See Also: + >>> # xdoctest: +SKIP + >>> class A(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + >>> unifiable(A) + + >>> x = var("x") + >>> a = A(1, 2) + >>> b = A(1, x) + >>> unify(a, b, {}) + {~x: 2} + """ + core_unify.add((cls, cls, dict), unify_object) # type: ignore[attr-defined] + core_reify.add((cls, dict), reify_object) # type: ignore[attr-defined] + + return cls + + +######### +# Reify # +######### + + +def reify_object(o, s): + """Reify a Python object with a substitution + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> print(f) + Foo(1, ~x) + >>> print(reify_object(f, {x: 2})) + Foo(1, 2) + """ + if hasattr(o, "__slots__"): + return _reify_object_slots(o, s) + else: + return _reify_object_dict(o, s) + + +def _reify_object_dict(o, s): + obj = object.__new__(type(o)) + d = reify(o.__dict__, s) + if d == o.__dict__: + return o + obj.__dict__.update(d) + return obj + + +def _reify_object_slots(o, s): + attrs = [getattr(o, attr) for attr in o.__slots__] + new_attrs = reify(attrs, s) + if attrs == new_attrs: + return o + else: + newobj = object.__new__(type(o)) + for slot, attr in zip(o.__slots__, new_attrs): + setattr(newobj, slot, attr) + return newobj + + +@dispatch(slice, dict) +def _reify(o, s): + """Reify a Python ``slice`` object""" + # pyrefly: ignore [not-iterable] + return slice(*reify((o.start, o.stop, o.step), s)) + + +######### +# Unify # +######### + + +def unify_object(u, v, s): + """Unify two Python objects + Unifies their type and ``__dict__`` attributes + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> g = Foo(1, 2) + >>> unify_object(f, g, {}) + {~x: 2} + """ + if type(u) is not type(v): + return False + if hasattr(u, "__slots__"): + return unify( + [getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s, + ) + else: + return unify(u.__dict__, v.__dict__, s) + + +@dispatch(slice, slice, dict) +def _unify(u, v, s): + """Unify a Python ``slice`` object""" + return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4216a79ad0351cc6fedba64c06810fdc894426 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py @@ -0,0 +1,420 @@ +# mypy: allow-untyped-defs +import collections +import operator +from collections.abc import Mapping +from functools import reduce + + +__all__ = [ + "merge", + "merge_with", + "valmap", + "keymap", + "itemmap", + "valfilter", + "keyfilter", + "itemfilter", + "assoc", + "dissoc", + "assoc_in", + "update_in", + "get_in", +] + + +def _get_factory(f, kwargs): + factory = kwargs.pop("factory", dict) + if kwargs: + raise TypeError( + f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" + ) + return factory + + +def merge(*dicts, **kwargs): + """Merge a collection of dictionaries + + >>> merge({1: "one"}, {2: "two"}) + {1: 'one', 2: 'two'} + + Later dictionaries have precedence + + >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) + {1: 2, 3: 3, 4: 4} + + See Also: + merge_with + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge, kwargs) + + rv = factory() + for d in dicts: + rv.update(d) + return rv + + +def merge_with(func, *dicts, **kwargs): + """Merge dictionaries and apply function to combined values + + A key may occur in more than one dict, and all values mapped from the key + will be passed to the function as a list, such as func([val1, val2, ...]). + + >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) + {1: 11, 2: 22} + + >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP + {1: 1, 2: 2, 3: 30} + + See Also: + merge + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge_with, kwargs) + + result = factory() + for d in dicts: + for k, v in d.items(): + if k not in result: + result[k] = [v] + else: + result[k].append(v) + return valmap(func, result, factory) + + +def valmap(func, d, factory=dict): + """Apply function to values of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> valmap(sum, bills) # doctest: +SKIP + {'Alice': 65, 'Bob': 45} + + See Also: + keymap + itemmap + """ + rv = factory() + rv.update(zip(d.keys(), map(func, d.values()))) + return rv + + +def keymap(func, d, factory=dict): + """Apply function to keys of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> keymap(str.lower, bills) # doctest: +SKIP + {'alice': [20, 15, 30], 'bob': [10, 35]} + + See Also: + valmap + itemmap + """ + rv = factory() + rv.update(zip(map(func, d.keys()), d.values())) + return rv + + +def itemmap(func, d, factory=dict): + """Apply function to items of dictionary + + >>> accountids = {"Alice": 10, "Bob": 20} + >>> itemmap(reversed, accountids) # doctest: +SKIP + {10: "Alice", 20: "Bob"} + + See Also: + keymap + valmap + """ + rv = factory() + rv.update(map(func, d.items())) + return rv + + +def valfilter(predicate, d, factory=dict): + """Filter items in dictionary by value + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> valfilter(iseven, d) + {1: 2, 3: 4} + + See Also: + keyfilter + itemfilter + valmap + """ + rv = factory() + for k, v in d.items(): + if predicate(v): + rv[k] = v + return rv + + +def keyfilter(predicate, d, factory=dict): + """Filter items in dictionary by key + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> keyfilter(iseven, d) + {2: 3, 4: 5} + + See Also: + valfilter + itemfilter + keymap + """ + rv = factory() + for k, v in d.items(): + if predicate(k): + rv[k] = v + return rv + + +def itemfilter(predicate, d, factory=dict): + """Filter items in dictionary by item + + >>> def isvalid(item): + ... k, v = item + ... return k % 2 == 0 and v < 4 + + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> itemfilter(isvalid, d) + {2: 3} + + See Also: + keyfilter + valfilter + itemmap + """ + rv = factory() + for item in d.items(): + if predicate(item): + k, v = item + rv[k] = v + return rv + + +def assoc(d, key, value, factory=dict): + """Return a new dict with new key value pair + + New dict has d[key] set to value. Does not modify the initial dictionary. + + >>> assoc({"x": 1}, "x", 2) + {'x': 2} + >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP + {'x': 1, 'y': 3} + """ + d2 = factory() + d2.update(d) + d2[key] = value + return d2 + + +def dissoc(d, *keys, **kwargs): + """Return a new dict with the given key(s) removed. + + New dict has d[key] deleted for each supplied key. + Does not modify the initial dictionary. + + >>> dissoc({"x": 1, "y": 2}, "y") + {'x': 1} + >>> dissoc({"x": 1, "y": 2}, "y", "x") + {} + >>> dissoc({"x": 1}, "y") # Ignores missing keys + {'x': 1} + """ + factory = _get_factory(dissoc, kwargs) + d2 = factory() + + if len(keys) < len(d) * 0.6: + d2.update(d) + for key in keys: + if key in d2: + del d2[key] + else: + remaining = set(d) + remaining.difference_update(keys) + for k in remaining: + d2[k] = d[k] + return d2 + + +def assoc_in(d, keys, value, factory=dict): + """Return a new dict with new, potentially nested, key value pair + + >>> purchase = { + ... "name": "Alice", + ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} + """ + return update_in(d, keys, lambda x: value, value, factory) + + +def update_in(d, keys, func, default=None, factory=dict): + """Update value in a (potentially) nested dictionary + + inputs: + d - dictionary on which to operate + keys - list or tuple giving the location of the value to be changed in d + func - function to operate on that value + + If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the + original dictionary with v replaced by func(v), but does not mutate the + original dictionary. + + If k0 is not a key in d, update_in creates nested dictionaries to the depth + specified by the keys, with the innermost value set to func(default). + + >>> inc = lambda x: x + 1 + >>> update_in({"a": 0}, ["a"], inc) + {'a': 1} + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} + + >>> # updating a value when k0 is not in d + >>> update_in({}, [1, 2, 3], str, default="bar") + {1: {2: {3: 'bar'}}} + >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) + {1: 'foo', 2: {3: {4: 1}}} + """ + ks = iter(keys) + k = next(ks) + + rv = inner = factory() + rv.update(d) + + # pyrefly: ignore [not-iterable] + for key in ks: + if k in d: + d = d[k] + dtemp = factory() + dtemp.update(d) + else: + d = dtemp = factory() + + inner[k] = inner = dtemp + k = key + + if k in d: + inner[k] = func(d[k]) + else: + inner[k] = func(default) + return rv + + +def get_in(keys, coll, default=None, no_default=False): + """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + + If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless + ``no_default`` is specified, then it raises KeyError or IndexError. + + ``get_in`` is a generalization of ``operator.getitem`` for nested data + structures such as dictionaries and lists. + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> get_in(["purchase", "items", 0], transaction) + 'Apple' + >>> get_in(["name"], transaction) + 'Alice' + >>> get_in(["purchase", "total"], transaction) + >>> get_in(["purchase", "items", "apple"], transaction) + >>> get_in(["purchase", "items", 10], transaction) + >>> get_in(["purchase", "total"], transaction, 0) + 0 + >>> get_in(["y"], {}, no_default=True) + Traceback (most recent call last): + ... + KeyError: 'y' + + See Also: + itertoolz.get + operator.getitem + """ + try: + return reduce(operator.getitem, keys, coll) + except (KeyError, IndexError, TypeError): + if no_default: + raise + return default + + +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + elif index: + return operator.itemgetter(*index) + else: + return lambda x: () + else: + return operator.itemgetter(index) + + +def groupby(key, seq): + """Group a collection by a key function + + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + Non-callable keys imply grouping on a member. + + >>> groupby( + ... "gender", + ... [ + ... {"name": "Alice", "gender": "F"}, + ... {"name": "Bob", "gender": "M"}, + ... {"name": "Charlie", "gender": "M"}, + ... ], + ... ) # doctest:+SKIP + {'F': [{'gender': 'F', 'name': 'Alice'}], + 'M': [{'gender': 'M', 'name': 'Bob'}, + {'gender': 'M', 'name': 'Charlie'}]} + + Not to be confused with ``itertools.groupby`` + + See Also: + countby + """ + if not callable(key): + key = getter(key) + d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] + for item in seq: + d[key(item)](item) + rv = {} + for k, v in d.items(): + rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] + return rv + + +def first(seq): + """The first element in a sequence + + >>> first("ABC") + 'A' + """ + return next(iter(seq)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ab99ad1b4f0d495067cb33b8464c7c80777f7d8d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + + +def transitive_get(key, d): + """Transitive dict.get + >>> d = {1: 2, 2: 3, 3: 4} + >>> d.get(1) + 2 + >>> transitive_get(1, d) + 4 + """ + while hashable(key) and key in d: + key = d[key] + return key + + +def raises(err, lamda): # codespell:ignore lamda + try: + lamda() # codespell:ignore lamda + return False + except err: + return True + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> # xdoctest: +SKIP + >>> _toposort({1: (2, 3), 2: (3,)}) + [1, 2, 3] + Closely follows the wikipedia page [2] + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = {v for v in edges if v not in incoming_edges} + L = [] + + while S: + n = S.pop() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S.add(m) + if any(incoming_edges.get(v) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = {} # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +def xfail(func): + try: + func() + raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002 + except Exception: + pass + + +def freeze(d): + """Freeze container to hashable form + >>> freeze(1) + 1 + >>> freeze([1, 2]) + (1, 2) + >>> freeze({1: 2}) # doctest: +SKIP + frozenset([(1, 2)]) + """ + if isinstance(d, dict): + return frozenset(map(freeze, d.items())) + if isinstance(d, set): + return frozenset(map(freeze, d)) + if isinstance(d, (tuple, list)): + return tuple(map(freeze, d)) + return d diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5b51aaf99a5dc9864f5aa22fa9c50571f95797 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from .dispatch import dispatch +from .utils import hashable + + +_global_logic_variables = set() # type: ignore[var-annotated] +_glv = _global_logic_variables + + +class Var: + """Logic Variable""" + + _id = 1 + + def __new__(cls, *token): + if len(token) == 0: + token = f"_{Var._id}" # type: ignore[assignment] + Var._id += 1 + elif len(token) == 1: + token = token[0] + + obj = object.__new__(cls) + obj.token = token # type: ignore[attr-defined] + return obj + + def __str__(self): + return "~" + str(self.token) # type: ignore[attr-defined] + + __repr__ = __str__ + + def __eq__(self, other): + return type(self) is type(other) and self.token == other.token # type: ignore[attr-defined] + + def __hash__(self): + return hash((type(self), self.token)) # type: ignore[attr-defined] + + +def var(): + return lambda *args: Var(*args) + + +def vars(): + return lambda n: [var() for i in range(n)] + + +@dispatch(Var) +def isvar(v): + return True + + +isvar + + +@dispatch(object) # type: ignore[no-redef] +def isvar(o): + return _glv and hashable(o) and o in _glv + + +@contextmanager +def variables(*variables): + """ + Context manager for logic variables + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> from __future__ import with_statement + >>> with variables(1): + ... print(isvar(1)) + True + >>> print(isvar(1)) + False + >>> # Normal approach + >>> from unification import unify + >>> x = var("x") + >>> unify(x, 1) + {~x: 1} + >>> # Context Manager approach + >>> with variables("x"): + ... print(unify("x", 1)) + {'x': 1} + """ + old_global_logic_variables = _global_logic_variables.copy() + _global_logic_variables.update(set(variables)) + try: + yield + finally: + _global_logic_variables.clear() + _global_logic_variables.update(old_global_logic_variables) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py new file mode 100644 index 0000000000000000000000000000000000000000..efafb146179a6c35e0a5ccb9a29893aa3a379a87 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] +from torch.fx.tensor_type import TensorType + + +def infer_symbolic_types_single_pass(traced): + """ + Calls our symbolic inferencer once. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + +def infer_symbolic_types(traced): + """ + Calls our symbolic inferencer twice. + This is useful when one pass is not enough + to infer all the information such as the case + for braodcasting. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() + + +def convert_eq(list_of_eq): + """ + Convert equality constraints in the right format + to be used by unification library. + """ + lhs = [] + rhs = [] + for eq in list_of_eq: + lhs.append(eq.lhs) + rhs.append(eq.rhs) + return tuple(lhs), tuple(rhs) + + +def unify_eq(list_of_eq): + """ + Apply unification to a set of + equality constraints + """ + lhs, rhs = convert_eq(list_of_eq) + return unify(lhs, rhs) + + +def substitute_solution_one_type(mapping, t): + """ + Apply the most general unifier to a type + """ + if isinstance(t, Var): + if t in mapping: + return mapping[t] + else: + return t + + elif isinstance(t, TensorType): + new_type = [] + for typ in t.__args__: + if typ in mapping: + new_type.append(mapping[typ]) + else: + new_type.append(typ) + return TensorType(tuple(new_type)) + + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + + +def substitute_all_types(graph, mapping): + """ + Apply the most general unifier to all types in a graph + till reaching a fixed point. If the input and output graph + are the same, we converge. + """ + flag = True + while flag: + flag = False + for k in mapping: + old_mapping_val = mapping[k] + if mapping[k] in mapping: + new_key = mapping[k] + mapping[k] = mapping[new_key] + if old_mapping_val != mapping[k]: + flag = True + + for n in graph.nodes: + n.type = substitute_solution_one_type(mapping, n.type) + + +def check_for_type_equality(g1, g2): + """ + A check equality to be used in fixed points. + We do not use graph equality but instead type + equality. + """ + for n, m in zip(g1.nodes, g2.nodes): + if n.type != m.type: + return False + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..56b8b871626af81f23f4e88cc5f57161ed1287ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py @@ -0,0 +1,874 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import logging +import math +import operator +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import sympy + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch._dynamo.exc import TorchDynamoException +from torch._dynamo.utils import dynamo_timed +from torch.fx.node import Argument, Target +from torch.utils._sympy.interp import sympy_interp + + +log = logging.getLogger(__name__) + +try: + import z3 # type: ignore[import] + + # Translation Validation for Dynamo guards + # ======================================== + # + # Checks whether optimizations applied to the collected guards are + # valid. In other words, whether the guard function we actually run + # does not have false positives (unsound). + # + # In order to do so, we build the guards using 2 different information + # attached to each 'SymNode': + # 1. SymPy expressions + # 2. FX nodes + # + # SymPy expressions have implicit optimizations baked within itself, + # which may have a few bugs. On the other hand, we build the FX graph + # manually, with no optimizations enabled. This gives us access to + # the "ground truth". + # + # We then convert into Z3 expressions both the SymPy expressions + # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function + # and the FX nodes (see [Note: PopulateValidator]) that go through + # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. + # (see [Note: TranslationValidator]) + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> list[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + + # We need to convert to/from BitVec in order to use z3 bitwise ops. + # We assume that integers are 64 bit. + # If all args are boolean, then use the boolean bitwise op implementation instead, if provided. + def _bitwise_op(bitwise_func, bool_func): + @functools.wraps(bitwise_func) + def wrapper(self, *args): + if bool_func is not None and all( + isinstance(arg, z3.BoolRef) for arg in args + ): + return bool_func(*args) + + wrapped_args = tuple(z3.Int2BV(a, 64) for a in args) + return z3.BV2Int(bitwise_func(*wrapped_args)) + + return wrapper + + # Implementation of Python semantics as Z3 expressions. + # + # Z3 Real-Int theory has operators with semantics that differ that of + # Python. Therefore, in order to get it right, we need to implement + # the (Python) semantics we are relying on in Z3. + @dataclass + class _Z3Ops: + # Validator used for adding assertions as needed. + # e.g. div(a, b) requires b != 0. + validator: "TranslationValidator" + + # The 2 functions below are used for conditionally casting between + # integer and reals. + # + # Returns a real expression from 'x'. + @staticmethod + def to_real(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_real() else z3.ToReal(x) + + # Returns an integer expression from 'x'. + @staticmethod + def to_int(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_int() else z3.ToInt(x) + + def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef: + # pyrefly: ignore + return sum(args) + + # Implements Python division semantics. + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] + return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) + + def floor(self, number: z3.ArithRef) -> z3.ArithRef: + # Z3 ToInt function rounds a real number towards negative infinity. + return _Z3Ops.to_int(number) + + # Python semantics for 'FloorDiv' states that before applying the floor + # function, the operands are converted to their common type. + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + cast_result_to_real = numerator.is_real() or denominator.is_real() + result = _Z3Ops.to_int(self.div(numerator, denominator)) + # Since the 'result' is already an integer, we just have to check + # whether we should cast it to real. + return _Z3Ops.to_real(result) if cast_result_to_real else result + + def ceil(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value] + + def trunc(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] + + def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a > b, a, b) # type: ignore[return-value] + + def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a < b, a, b) # type: ignore[return-value] + + # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q + # It should work with both integer and reals. + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return p - self.floordiv(p, q) * q + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + # Z3 can't handle complex numbers very well. + self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] + return base**exp + + def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: + # Square-root: + # 1. Only work with reals + number = _Z3Ops.to_real(number) + # 2. The number should be positive or zero. + # Otherwise, Z3 returns 'unknown'. + self.validator.add_assertion(number >= 0) + return number**0.5 + + def abs(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.Abs(number) + + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + # Pythons builtin 'round' implements the 'round half to even' strategy + # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to + # floating point numbers, which is different from real numbers that we are dealing with here. + # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and + # 'round half down' (ceil(x - 0.5)). + # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... + # to round down, i.e. use the 'round half down' strategy + return z3.If( + self.mod(number, z3.IntVal(2)) == 0.5, + self.ceil(number - 0.5), + self.floor(number + 0.5), + ) + + bitwise_and = _bitwise_op(operator.and_, z3.And) + bitwise_or = _bitwise_op(operator.or_, z3.Or) + lshift = _bitwise_op(operator.lshift, None) + rshift = _bitwise_op(operator.rshift, None) + + # Lifts a callable to be used in Z3. + # + # This function replaces the given 'op' by a function that: + # + # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) + # + # 2. Calls an operation that corresponds to 'op', but works with Z3 + # inhabitants (left as is if it works as is) + def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + # Operations that have booleans as their argument. + # This is needed because the argument of some FX nodes were + # literal integers, instead of booleans. So, whenever this flag + # is set, we also convert ints to booleans. + boolean_ops = {operator.not_} + as_bool = op in boolean_ops + + # Lifts the function into 'z3.ExprRef' domain. + def lift(func): + def wrap(a) -> z3.ExprRef: + if isinstance(a, (z3.ArithRef, z3.BoolRef)): + return a + # Convert it into a Z3 value, if it is some of the supported + # types below. + if isinstance(a, bool) or (as_bool and isinstance(a, int)): + return z3.BoolVal(bool(a)) + if isinstance(a, (int, sympy.Integer)): + return z3.IntVal(int(a)) + if isinstance(a, (float, sympy.Float)): + return z3.RealVal(float(a)) + raise ValueError(f"can't lift type: {type(a)}") + + @functools.wraps(func) + def wrapper(*args): + # Lifts the arguments into a list of Z3 inhabitants. + if len(args) == 1 and isinstance(args[0], (list, tuple)): + wrapped_args = (tuple(wrap(a) for a in args[0]),) + else: + wrapped_args = tuple(wrap(a) for a in args) + # Run the function on the Z3 expressions. + return func(*wrapped_args) + + return wrapper + + ops = _Z3Ops(validator) + replacement_map = { + # Operator module. + operator.not_: lift(z3.Not), + operator.and_: lift(ops.bitwise_and), + operator.or_: lift(ops.bitwise_or), + operator.lshift: lift(ops.lshift), + operator.rshift: lift(ops.rshift), + operator.floordiv: lift(ops.floordiv), + operator.truediv: lift(ops.div), + operator.mod: lift(ops.mod), + operator.abs: lift(ops.abs), + builtins.round: lift(ops.round_to_int), + # Math module. + math.ceil: lift(ops.ceil), + math.floor: lift(ops.floor), + math.trunc: lift(ops.trunc), + # Torch module. + torch.sym_float: lift(ops.to_real), + torch.sym_max: lift(ops.max), + torch.sym_min: lift(ops.min), + torch.sym_sum: lift(ops.sym_sum), + torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] + # Not lifted because we only use this function as a + # marker for adding the expression as validator input. + torch._assert: torch._assert, + } + return replacement_map[op] if op in replacement_map else lift(op) + + # Processes an FX graph, populating the given validator. + # + # [Note: PopulateValidator] + # This class walks through each node in the FX graph, translating + # them into the Z3 world. + # + # Then, whenever it finds an 'torch._assert' call_function operation, + # it adds the Z3 expression corresponding to the argument as validator + # input. + class PopulateValidator(torch.fx.Interpreter): + def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + # Reference to the translation validator. + self.validator = validator + + # Build the graph module and call `Interpreter` constructor. + module = torch.fx.GraphModule(root={}, graph=graph) + super().__init__(module, garbage_collect_values=True) + + def placeholder( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + symbol = fx_traceback.get_current_meta()["symbol"] + return self.validator.z3var(symbol) + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if target is not torch._assert: + # Lift and runs the node target function + return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] + # Adds the Z3 expression corresponding to the first argument + # as a validator input. + assert len(args) == 1, ( + f"expected 1 argument on assertion. Got: {len(args)} " + ) + self.validator.add_source_expr(args[0]) # type: ignore[arg-type] + + # Translates SymPy expressions into Z3 expressions. + # + # [Note: SympyToZ3] + # At the time of the translation, all free variables present in the + # SymPy expression being translated must be already mapped to a Z3 + # integer variable. + class SympyToZ3: + OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} + + def __init__( + self, + validator: "TranslationValidator", + ) -> None: + self._validator = validator + self._ops = _Z3Ops(self._validator) + + def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision + if dtype is torch.int64: + return z3.IntVal(int(value)) + if dtype is torch.double: + return z3.RealVal(float(value)) + if dtype is torch.bool: + return z3.BoolVal(bool(value)) + raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def python_mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) + + def __getattr__(self, name: str) -> Any: + REPLACEMENT = { + "and_": z3.And, + "or_": z3.Or, + "not_": z3.Not, + "bitwise_and": self._ops.bitwise_and, + "bitwise_or": self._ops.bitwise_or, + "lshift": self._ops.lshift, + "rshift": self._ops.rshift, + "floor": self._ops.floor, + "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, + } + + if name in REPLACEMENT: + return REPLACEMENT[name] + if name in self.OPERATOR_HANDLES: + return getattr(operator, name) + raise AttributeError(f"unhandled operator: {name}") + + def run(self, expr: sympy.Basic) -> z3.ExprRef: + return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] + + # Dynamo guards translation validator. + # + # [Note: TranslationValidator] + # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. + # That is: whether those (target) guards only yield TRUE whenever the original, + # unoptimized, (source) guards yield TRUE. + # + # More concretely, given 'source' and 'target' guard expressions, we wish to + # check whether the following expression holds: + # + # Not(And(source)) AND And(target) + # + # i.e. whether there is an assignment of the free variables where the opposite + # happens: target is TRUE, but source is FALSE. + class TranslationValidator: + def __init__(self) -> None: + log.debug("new instance") + + # Mapping of SymPy symbols to Z3 variables. + self.symbols: dict[sympy.Symbol, z3.ExprRef] = {} + + # Set of source Z3 expressions. + # They represent the generated guards without any kind of + # simplification or transformation. + self._source_exprs: set[z3.BoolRef] = set() + + # Set of target Z3 expressions. + # They represent the actual checked guards at runtime. They might + # be simplified or transformed versions of the source guards. + self._target_exprs: set[z3.BoolRef] = set() + + # Set of Z3 expressions representing assertions over both the + # source and target expressions. + self._assertions: set[z3.BoolRef] = set() + + # Retrieves the corresponding Z3 variable. + def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: + assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" + return self.symbols[symbol] + + # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. + def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef: + if symbol in self.symbols: + return self.symbols[symbol] + + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + + if type is int: + var = z3.Int(symbol.name) + + # If 'symbol' is positive (SymPy assumption), we have to + # convey it to Z3 as well. + if symbol.is_positive: # type: ignore[attr-defined] + self._target_exprs.add(var > 0) + elif type is float: + var = z3.Real(symbol.name) + elif type is bool: + var = z3.Bool(symbol.name) + else: + raise RuntimeError(f"unsupported type for Z3 variable: {type}") + + self.symbols[symbol] = var + return var + + # Checks whether all symbols were already added. + def _check_freesymbols(self, e: sympy.Basic) -> None: + for s in e.free_symbols: + assert isinstance(s, sympy.Symbol) + # Call 'z3var' just to check whether there's already a + # Z3 variable corresponding to 's'. + self.z3var(s) + + def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: + z3expr = SympyToZ3(self).run(e) + assert isinstance(z3expr, z3.BoolRef), ( + f"expected boolean expression. Got: {z3expr}" + ) + return z3expr + + def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) + self._source_exprs.add(e) + + def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None: + self._check_freesymbols(e) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) + + def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: + if isinstance(e, sympy.Basic): + self._check_freesymbols(e) + ref = self.to_z3_boolean_expr(e) + else: + ref = e + assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) + self._assertions.add(ref) + + def validate(self) -> None: + with dynamo_timed("TranslationValidator.validate"): + return self._validate() + + def _validate(self) -> None: + if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: + # If there are no source/target expressions, there's nothing we really + # wish to prove. So, we just return. + return None + + # Here, we use "QF_NRA" logic for the solver: + # "Quantifier-free Non-linear Real Arithmetic". + # + # Most of the guards expressions have: + # 1. arithmetic between integer and reals + # 2. no quantifiers + # 3. potentially non-linear. + # + # Although there's also "QF_NIRA" (mixed integer-real arithmetic), + # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. + solver = z3.SolverFor("QF_NRA") + # Set a timeout for finding a solution. + solver.set(timeout=translation_validation_timeout()) + + # Add all the assertions to the solver. + for assertion in self._assertions: + solver.add(assertion) + + # "Is there any case where it's TRUE for the target expressions, + # but FALSE for the source expressions?" + solver.add(z3.Not(z3.And(*self._source_exprs))) + solver.add(*self._target_exprs) + + log.debug("translation validation: start") + r = solver.check() + if r == z3.sat: + # Target expressions are unsound. + # Log the found model and the source expressions that failed. + model = solver.model() + raise ValidationException( + model, + self._assertions, + self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ], + ) + else: + if r == z3.unknown: + # Could not find a solution. It didn't fail, but it also + # didn't succeed. Canceling the validation execution (keyboard + # interrupt) also gets to this branch. + log.warning( + "translation validation: could not validate: got z3.unknown" + ) + else: + # Target expressions are sound. + assert r == z3.unsat + log.debug("translation validation: success") + +except ImportError: + _HAS_Z3 = False + + __all__ = [ + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +else: + _HAS_Z3 = True + + __all__ = [ + "z3str", + "z3op", + "PopulateValidator", + "SympyToZ3", + "TranslationValidator", + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +from torch.fx.experimental import _config as config + + +def translation_validation_enabled() -> bool: + # Checks every time this function is called, in case the Dynamo + # option is set, but Z3 is not installed. + _assert_z3_installed_if_tv_set() + return _HAS_Z3 and config.translation_validation + + +def translation_validation_timeout() -> int: + return config.translation_validation_timeout + + +def _assert_z3_installed_if_tv_set(): + assert _HAS_Z3 or not config.translation_validation, ( + "translation validation requires Z3 package. Please, either install " + "z3-solver or disable translation validation." + ) + + +class ValidationException(TorchDynamoException): + def __init__(self, model, assertions, target_exprs, failed_source_exprs): + assert _HAS_Z3 + + def symbolstr(sym) -> str: + return f"{sym}: {model[sym]}" + + def joinlines(xs) -> str: + return "\n".join(f" ==> {x}" for x in xs) + + model_str = joinlines(sorted(map(symbolstr, model))) + assertions_str = joinlines(sorted(map(z3str, assertions))) + target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) + failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) + + self.msg = "translation validation failed." + self.details = f"""\ +Model: +{model_str} + +Assertions: +{assertions_str} + +Target Expressions: +{target_exprs_str} + +Failed Source Expressions: +{failed_source_exprs_str}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +class BisectValidationException(TorchDynamoException): + def __init__(self, validation_exc, expr, failed_action, traced_node): + self.msg = f"translation validation failed when {failed_action}: {expr}" + self.details = f"""\ +Failure occurred while running node: + {traced_node.format_node()} + +{validation_exc.details}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +# Checks when this module is loaded. +_assert_z3_installed_if_tv_set() + + +# Translation validation bisection. +# +# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise +# the earliest ValidationException. +# +# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors +# might be silently happening. This function tries to nail down exactly at which +# point things went wrong from a validation perspective. +def bisect(shape_env): + from torch.fx.experimental.recording import ( + FakeTensorMeta, + replay_shape_env_events, + ShapeEnvEvent, + ) + from torch.fx.experimental.symbolic_shapes import ( + CURRENT_NODE_KEY, + ShapeEnv, + SHAPEENV_EVENT_KEY, + ) + + events = shape_env.events + + # Retrieves the ShapeEnvEvent associated with node. + def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: + assert SHAPEENV_EVENT_KEY in node.meta + return events[node.meta[SHAPEENV_EVENT_KEY]] + + # Creates a new instance of fake, but updating every symbolic value's ShapeEnv + # reference to the one given as argument. + # + # This is needed so as not to simplify a symbolic expression using a ShapeEnv + # "from the future", where it may have a different set of replacements. + def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + if isinstance(fake, int): + return fake + if isinstance(fake, torch.SymInt): + return torch.SymInt(fake.node.with_shape_env(shape_env)) + if isinstance(fake, torch.SymFloat): + return torch.SymFloat(fake.node.with_shape_env(shape_env)) + assert isinstance(fake, FakeTensorMeta) + return FakeTensorMeta( + tuple(new_with_shape_env(shape_env, s) for s in fake.size()), + tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), + new_with_shape_env(shape_env, fake.storage_offset()), + fake.is_nested, + ) + + # Checks whether the given shape_env fails when produce_guards is called. + def check_shapeenv_fails( + shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]] + ) -> Optional[ValidationException]: + assert tracked_fakes is not None + try: + # This produce_guards call is a best-effort replication, since we + # don't populate EqualityConstraint list. Reason: we would also have + # to save OutputGraph.tracked_fakes_id_to_source. + shape_env.produce_guards( + [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], + [a.source for a in tracked_fakes], + input_contexts=[a.symbolic_context for a in tracked_fakes], + ) + return None + except ValidationException as e: + return e + + # Checks whether the ShapeEnv reconstructed by replaying the events until + # node is created fails when produce_guards is called. + def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: + number = node.meta[SHAPEENV_EVENT_KEY] + # Reconstruct shape_env until the event at event_number. + shape_env = replay_shape_env_events(events[: number + 1]) + shape_env.graph.lint() + return check_shapeenv_fails(shape_env, events[number].tracked_fakes) + + last_exception = check_shapeenv_fails( + shape_env, shape_env._snapshot_tracked_fakes() + ) + + if not last_exception: + # We don't actually fail due to a produce_guards call. + # Stop and don't bisect. + log.info("translation validation succeeded: no errors found.") + return + + if not shape_env.should_record_events or config.translation_validation_no_bisect: + # Bisection is off. + # Return the last ValidationException we got. + raise last_exception + + # Cache the raised exception (if any) at each bisection point. + exception = {} + + # Bisection happens on the assertion nodes of the recorded FX graph for + # dynamic shapes. + assert_nodes = [ + node for node in shape_env.graph.nodes if node.target is torch._assert + ] + + # Preparing the indices for binary search. + # The overall invariants are + # - for all i < left, assert_node[i] doesn't fail + # - for all i >= right, assert_node[i] fails + # - `right in exception` always holds + # - `left <= right` always holds + left, mid, right = 0, 0, len(assert_nodes) - 1 + exception[right] = check_node_fails(assert_nodes[right]) + + while left < right: + mid = (left + right) // 2 + + node = assert_nodes[mid] + log.debug("bisecting at %s: %s", mid, get_node_event(node)) + + # Check whether the new shape_env raises a ValidationException or not. + exception[mid] = check_node_fails(node) + + if exception[mid]: + right = mid + else: + left = mid + 1 + + assert left in exception and isinstance(exception[left], ValidationException) + + node = assert_nodes[left] + event = get_node_event(node) + + if event.is_evaluate_expr(): + failed_action = "evaluating" + else: + assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" + failed_action = "adding runtime assert" + + args = event.args + assert args is not None + assert len(args) >= 2, ( + f"bisecting expects {event.name} to have at least 2 positional arguments. " + f"Got: {len(args)}" + ) + assert isinstance(args[1], sympy.Basic), ( + f"bisecting expects {event.name} to have a SymPy expression as its second argument. " + f"Got: {type(args[1])}" + ) + + raise BisectValidationException( + exception[left], + expr=args[1], + failed_action=failed_action, + traced_node=node.meta[CURRENT_NODE_KEY], + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cffbe4fd13de3fe165b13cddda2b8cb1697f6179 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dea149f66aea38dd2d7629193a3d5f80056c613f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1037256fc9c3bcdd40a3206e94c8c0ffe1200cea Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f0e3afbe5b687a4c77598dc103e07e932f86694 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86a6887de531c67cd5d6594f5ad2f7beba281472 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8cec1b03e463b97784a156fc52ee2eeeaccb5f3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446389e7146a08b02cdc7ee0de1919b303d0da77 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52c7b487a14e73a2900b2d37cc7dd8b29bb82e55 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb037757b649399d46763b5c9b4506ea1e31a52a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c34037e5f5f27ffbfb4864ef053b4a8817b3ef7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fc846f129d67cb07c432cd5e54f2dca6c1f7054 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2c98790e1eb36b917e5ac4377f22bf01fc4cca Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/regional_inductor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1ffaa2ccb20b45236b341ade032cb13f9944867 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed1cbb41cd4187d6ae1e62808a6b5a2c7b7671c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5e8870c9f0b992619990ef5f6f6de74803c7ea7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eef41d0f62ecc5b84a0123380fbfd98786399270 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f03cd34b62c06654522dbde82c7d13e96856b267 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2905d2fb0d7212f3c4de12405ed6034749c75409 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58588e8028480b878a3721b0977b118292ce6bb2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e833ce3a4127cffb841686395b88c4e41645c073 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f88752c66f2916c99ebd95573a55d67c24b499e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..97496fbc9b2a2439b687bc09c58bb4031b8fc670 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.utils import _pytree as pytree + + +class CudaGraphsSupport(OperatorSupport): + # TODO: why is submodules passed here + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op not in CALLABLE_NODE_OPS: + return False + + if node.target is torch.ops.aten.embedding_dense_backward.default: + return False + + if node.target is operator.getitem: + return True + + found_not_cuda = False + + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + def find_not_cuda(t): + nonlocal found_not_cuda + if isinstance(t, torch.Tensor) and t.device.type != "cuda": + found_not_cuda = True + + for n in node.all_input_nodes: + pytree.tree_map_(find_not_cuda, meta_fk(n.meta)) + + pytree.tree_map_(find_not_cuda, meta_fk(node.meta)) + + # NB: factory function is accounted for because the result would be + # cpu or cuda + + return not found_not_cuda + + +def partition_cudagraphs(gm, inputs): + """ + Partition an FX graph into sub-GraphModules that can be validly run under + CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations + must involve CUDA tensors only/ + """ + + FakeTensorProp(gm).propagate(*inputs) + supported_ops = CudaGraphsSupport() + # TODO: single node partition may be wrong due to the pessimization + # from copying in and out the data. Check in benchmarks, perhaps + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) + partitions = partitioner.propose_partitions() + fused_graph = partitioner.fuse_partitions(partitions) + return fused_graph diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3934e3dc69e574777cb7b656e7d2240ff1766090 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb82b9b41319ea915fdb3c5f5595f1b2ea814a39 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77e95e9069ef1c5a47240b210a349e2ae73711e2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..e5889375bb07ae0f56917aff9950db67ff3f4bec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,155 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch.fx import Graph, GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = { + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +} # noqa: E501,B950 + +inplace_ops = { + aten.add_, + aten.sub_, + aten.mul_, + aten.div_, + aten.pow_, + aten.lerp_, + aten.relu_, + aten.sigmoid_, + aten.tanh_, +} # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + + def get_aten_target(node): + if hasattr(node.target, "overloadpacket"): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: dict[ + Node, Node + ] = {} # map from node in the old graph to node in the new graph + hash_env: dict[ + tuple[torch._ops.OpOverload, int], Node + ] = {} # map from hash to a node in the new graph + token_map: dict[ + tuple[torch._ops.OpOverload, int], dict[str, Any] + ] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in self.banned_ops + ): + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substitution happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..939157f1302e75e3cf17ec3c1e93d1b8993d67a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py @@ -0,0 +1 @@ +from . import pass_manager diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb536dbba9399d9ae5d5df53966f75f3a2a18a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,400 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import logging +import operator +from collections.abc import Iterable, Sequence +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx.node import _get_qualified_name, Node +from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class Partition: + def __init__( + self, + id: Optional[int] = None, + nodes: Optional[Iterable[Node]] = None, + node_orders: Optional[Iterable[int]] = None, + ): + self.id = id + self.nodes: dict[Node, Optional[int]] = {} + if nodes is not None: + if node_orders is None: + self.nodes = dict.fromkeys(nodes, None) + else: + nodes_list = list(nodes) + node_orders_list = list(node_orders) + assert len(nodes_list) == len(node_orders_list), ( + "nodes and node_orders must have the same length" + ) + self.nodes = dict(zip(nodes_list, node_orders_list)) + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node, node_order: Optional[int] = None): + self.nodes.update({node: node_order}) + + def remove_node(self, node: Node): + del self.nodes[node] + + def size(self): + return len(self.nodes) + + +class _DependencyViewer: + def __init__(self, graph_module: GraphModule): + self.downstreams = collections.defaultdict(set) + + for node in reversed(graph_module.graph.nodes): + for output_node in node.users: + # add output_node and output_node's downstream dependency + self.downstreams[node].add(output_node) + self.downstreams[node].update(self.downstreams[output_node]) + + def downstreams_of(self, node: Node) -> set[Node]: + return self.downstreams[node] + + +class CapabilityBasedPartitioner: + def __init__( + self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + self.non_compute_ops = non_compute_ops if non_compute_ops is not None else [] + self.allowed_single_node_partition_ops = ( + allowed_single_node_partition_ops + if allowed_single_node_partition_ops is not None + else [] + ) + self.dependency_viewer = _DependencyViewer(graph_module) + + def _is_node_supported(self, node: Node) -> bool: + return self.operator_support.is_node_supported( + dict(self.graph_module.named_modules()), node + ) + + def propose_partitions(self) -> list[Partition]: + # partition_map is a mapping from partition id to a set of partition id's. + # The value set contains all the partition ids that can be reached by doing a + # DFS starting from the partition id in the key. + partition_map: dict[int, set] = collections.defaultdict(set) + + # assumptions: nodes in candidate list is sorted in topological order + assignment: dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: dict[ + int, Partition + ] = {} # mapping from partition_id to partition + nodes_order: dict[ + Node, int + ] = {} # mapping from nodes to reversed topological order + partitions_order: dict[ + int, int + ] = {} # mapping from partition_id to minimum topo order of nodes in partition + partition_users: dict[ + int, set + ] = {} # mapping from partition_id to partition users + new_partition_id = itertools.count() + + # try to merge partition other_id into partition self_id + # merge only happens if the end graph doesn't contain cyclic dependency + # returns `True` when merge happens, `False` otherwise. + def maybe_merge_partition(self_id: int, other_id: int): + # merged_nodes is the union of nodes in two partition to-be-merged + self_nodes = partitions_by_id[self_id].nodes + other_nodes = partitions_by_id[other_id].nodes + + def dfs_iter_find_cycle(all_user_nodes: set[Node]): + for user_node in all_user_nodes: + visited_partition_ids = set() + + for path_node in self.dependency_viewer.downstreams_of(user_node): + # If any of the nodes in the dfs path of this node are in the merged_nodes + # list then there is a cycle in the graph. + if path_node in self_nodes or path_node in other_nodes: + return True + + # If any of the nodes in the dfs path of this node are in the assignment + # map then we have to make sure that the partitions that these nodes belong + # to do not form a cycle with the current partitions being merged. This means + # iterating through all the nodes in all the parititons that are traversed in + # the dfs path and checking if they are in the merged_nodes list. + if path_node in assignment: + partition_id = assignment[path_node] + # If the partition id has already been visited then we know that it doesn't + # form a cycle with the current partitions being merged. + if partition_id in visited_partition_ids: + continue + p_map = partition_map[partition_id] + if self_id in p_map or other_id in p_map: + return True + + visited_partition_ids.add(partition_id) + + return False + + # find new partition users if merge. + all_user_nodes = partition_users[self_id] | partition_users[other_id] + all_user_nodes.difference_update(other_nodes, self_nodes) + + # check if merge would create cyclic dependency. + if dfs_iter_find_cycle(all_user_nodes): + # return false indicating cyclic dependency found and + # merge is aborted + return self_id, False + + # merge the smaller partition into the larger. + merge_id, removed_id = self_id, other_id + if len(self_nodes) < len(other_nodes): + merge_id, removed_id = removed_id, merge_id + # no cyclic dependency found, move forward with the merge + # updating partition nodes + partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes) + # updating assignment map + for node in partitions_by_id[removed_id].nodes: + assignment[node] = merge_id + # delete other partition + del partitions_by_id[removed_id] + + partitions_order[merge_id] = min( + partitions_order[merge_id], partitions_order[removed_id] + ) + del partitions_order[removed_id] + + partition_map[merge_id] = partition_map[merge_id].union( + partition_map[removed_id] + ) + del partition_map[removed_id] + + partition_users[merge_id] = all_user_nodes + del partition_users[removed_id] + + return merge_id, True + + def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]): + def _update_partition_map(node: Node, id: int): + # Iterate through all the users of this node and update the partition map to indicate + # that there is a path from the partition id of this node to the target partition id. + for user_node in node.users: + target_id = assignment.get(user_node) + if target_id is not None: + partition_map[id].add(target_id) + partition_map[id].update(partition_map[target_id]) + + if node in assignment: + partitions_by_id[assignment[node]].remove_node(node) + + if id is None: + assignment.pop(node) + elif id not in partitions_by_id: + assignment[node] = id + assert node_order is not None + partitions_by_id[id] = Partition( + id=id, nodes=[node], node_orders=[node_order] + ) + partition_users[id] = set(node.users) + _update_partition_map(node, id) + else: + assignment[node] = id + partitions_by_id[id].add_node(node, node_order) + + logger.debug("Proposing partitions...") + + for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + merge_candidates: dict[int, None] = {} + + # Note a limited horizontal fusion is enabled: + # when `node` is not supported, the code below attempts to fuse consumer of `node`. + # + # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut + # the fusion by adding an `else` block here to skip horizontal fusion. + if self._is_node_supported(node) and node not in assignment: + partition_id = next(new_partition_id) + nodes_order[node] = partition_id + partitions_order[partition_id] = partition_id + merge_single_node(node, node_order, partition_id) + merge_candidates[partition_id] = None + + # merge all possible partitions + for partition_id, _ in sorted( + partitions_order.items(), key=operator.itemgetter(1) + ): + merge_candidates[partition_id] = None + + merge_candidates_list = list(merge_candidates.keys()) + if len(merge_candidates_list) > 1: + self_id = merge_candidates_list[0] + for other_id in merge_candidates_list[1:]: + # note: merge partitions if it doesn't create cyclic dependency + # in the graph, otherwise, this is a no-op + self_id, _ = maybe_merge_partition(self_id, other_id) + + # sort partition nodes based on descending node order + for partition in partitions_by_id.values(): + partition.nodes = dict( + sorted( + partition.nodes.items(), key=operator.itemgetter(1), reverse=True + ) + ) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if ( + user.op != "call_function" + or _get_qualified_name(user.target) != "_operator.getitem" + ): # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id # type: ignore[assignment] + for node, id in nodes_reassignment.items(): + merge_single_node(node, None, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops)) + partitions_to_remove: list[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function": + assert callable(node.target) + if _get_qualified_name(node.target) not in non_compute_ops: + compute_node_count += 1 + if ( + _get_qualified_name(node.target) + in self.allowed_single_node_partition_ops + ): + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logger.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logger.debug( + "partition #%s: %s", id, [node.name for node in partition.nodes] + ) + + return [ + partition for partition in partitions_by_id.values() if partition.size() > 0 + ] + + def fuse_partitions( + self, partitions: list[Partition], prefix: str = "fused_" + ) -> GraphModule: + logger.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ] + return fuse_by_partitions( + self.graph_module, + [partition.nodes for partition in partitions], + prefix=prefix, + ) + + # remove non-compute-ops that sits at the boundary of a partition. + def remove_bookend_non_compute_ops(self, partitions: list[Partition]): + non_compute_ops = set(self.non_compute_ops) + + def is_non_compute_node(node: Node): + return ( + node.op == "call_function" + and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + ) + + # cache transparent nodes + transparent_input_nodes: dict[Node, bool] = {} + transparent_output_nodes: dict[Node, bool] = {} + + def is_transparent_input_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_input_nodes: + return transparent_input_nodes[node] + if is_non_compute_node(node): + for input_n in node.all_input_nodes: + if not is_transparent_input_node(input_n, partition, removed_nodes): + transparent_input_nodes[node] = False + return False + transparent_input_nodes[node] = True + return True + transparent_input_nodes[node] = False + return False + + def is_transparent_output_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_output_nodes: + return transparent_output_nodes[node] + if is_non_compute_node(node): + for output_n in node.users: + if not is_transparent_output_node( + output_n, partition, removed_nodes + ): + transparent_output_nodes[node] = False + return False + transparent_output_nodes[node] = True + return True + transparent_output_nodes[node] = False + return False + + for partition in partitions: + # Note it's ok to use `set` here, since we are only query if a node + # has been removed. We are NEVER going to iterate on nodes inside + # the set. + remove_node: set[Node] = set() + for node in partition.nodes: + if is_non_compute_node(node) and ( + is_transparent_input_node(node, set(partition.nodes), remove_node) + or is_transparent_output_node( + node, set(partition.nodes), remove_node + ) + ): + remove_node.add(node) + + if len(remove_node) != 0: + for node in remove_node: + partition.nodes.pop(node, None) + + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions, prefix=prefix) + return fused_gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..32c641031b31f2c49ca76daac6751b356e740213 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,79 @@ +# mypy: allow-untyped-defs +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = ["PassResult", "PassBase"] + + +@compatibility(is_backward_compatible=False) +# pyrefly: ignore [invalid-inheritance] +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + + __slots__ = () + + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + + def requires(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + + def ensures(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..87fb6e70037f9a00c46143f87efc5a832a7db3ae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,310 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from collections.abc import Callable +from functools import wraps +from queue import Queue + +import torch.nn as nn +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"] + + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + # pyrefly: ignore [bad-return] + return None + + @wraps(fn) + def wrapped_fn(gm): + res = fn(gm) + if res is None: + return PassResult(gm, True) + if isinstance(res, PassResult): + return res + elif isinstance(res, nn.Module): + return PassResult(res, True) + + if not inspect.isfunction(fn): + wrapped_fn.__name__ = type(fn).__name__ + + return wrapped_fn + + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: list[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def _topological_sort_passes( + passes: list[Callable], constraints: list[Callable] +) -> list[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Construct a graph mapping nodes to a list of their users + graph: dict[Callable, list[Callable]] = {p: [] for p in passes} + indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: dict[Callable, bool] = dict.fromkeys(passes, False) + sorted_passes: list[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = ( + f"Circular dependency detected within the following passes: {cycle_passes}" + ) + raise RuntimeError(error) + + return sorted_passes + + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: list[Callable[[nn.Module], PassResult]] + constraints: list[Callable[[Callable, Callable], bool]] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + self.passes = passes or [] + self.constraints = constraints or [] + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError( + "PassManager check function should only take in one variable, a module" + ) + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for i, fn in enumerate(self.passes): + fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + logger.debug("Running pass '%s'", fn_name) + + try: + res = fn(module) + + if not isinstance(res, PassResult) and not hasattr( + res, "graph_module" + ): + raise TypeError( + f"The result of the pass {fn_name} should be type PassResult." + + "Please wrap it with pass_result_wrapper()" + ) + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + logger.debug("Graph after pass '%s': %s", fn_name, module.graph) + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + except Exception as e: + prev_pass_names = [ + p.__name__ if inspect.isfunction(p) else type(p).__name__ + for p in self.passes[:i] + ] + msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}" + raise Exception(msg) from e # noqa: TRY002 + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4e1bd496674d6ebc964bbda079c0d685a156909 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e565ac0e4d2ba422756eb800ca05e3c7da90b96 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..157dc4017eda576f10793ef46b78cd97b0f5074b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py @@ -0,0 +1,56 @@ +import unittest + +from ..pass_manager import ( + inplace_wrapper, + PassManager, + these_before_those_pass_constraint, + this_before_that_pass_constraint, +) + + +class TestPassManager(unittest.TestCase): + def test_pass_manager_builder(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + pm.validate() + + def test_this_before_that_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + + # add unfulfillable constraint + pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) + + self.assertRaises(RuntimeError, pm.validate) + + def test_these_before_those_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + constraint = these_before_those_pass_constraint(passes[-1], passes[0]) + pm = PassManager([inplace_wrapper(p) for p in passes]) + + # add unfulfillable constraint + pm.add_constraint(constraint) + + self.assertRaises(RuntimeError, pm.validate) + + def test_two_pass_managers(self) -> None: + """Make sure we can construct the PassManager twice and not share any + state between them""" + + passes = [lambda x: 2 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm1 = PassManager() + for p in passes: + pm1.add_pass(p) + pm1.add_constraint(constraint) + output1 = pm1(1) + self.assertEqual(output1, 2**3) + + passes = [lambda x: 3 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm2 = PassManager() + for p in passes: + pm2.add_pass(p) + pm2.add_constraint(constraint) + output2 = pm2(1) + self.assertEqual(output2, 3**3) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5e7e66868a0776609ff7ffff458f6a91ccf98a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py @@ -0,0 +1 @@ +from .common import compare_graphs, HolderModule, lift_subgraph_as_module diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08e8bb0001c24cd8df92f41b61aad130ff2743d2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37954d4815c084f32d55fbc313ffd2f976fedc15 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56f37a809bd3a493ee609136f73033ce20925e04 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04ec6a3c250af7eca8bd34d82cca57bc0cd08125 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25471718dad5a511cb8ca0c98fe59ec8b35a3269 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb3b868ef36e7857615aa2de8af5cf43bc2d7de2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..4c97aa4093571604953f12f8ff4711fb401ca9c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.nn import Module + + +__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"] + + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module( + gm: GraphModule, + subgraph: Graph, + comp_name: str = "", + class_name: str = "GraphModule", +) -> tuple[GraphModule, dict[str, str]]: + """ + Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + comp_name (str): name for the new component + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. + submodule = HolderModule({}) + orig_to_split_fqn_mapping: dict[str, str] = {} + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + # pyrefly: ignore [missing-attribute] + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}" + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping + + +@compatibility(is_backward_compatible=False) +def compare_graphs(left: Graph, right: Graph) -> bool: + """ + Return True if two graphs are identical, i.e they + - have the same number of outputs in the same order + - have the same number of inputs in the same order + - have the same set of nodes, and identical connectivity + """ + + matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) + matches = matcher.match(right) + + return len(matches) > 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0571c92f61b765732d34f06ba09080dc11a66b60 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py @@ -0,0 +1,294 @@ +import copy +from queue import SimpleQueue +from typing import Optional as _Optional + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.passes.tools_common import ( # noqa: F401 + legalize_graph, + NodeList, + NodeSet, + stable_topological_sort, +) +from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] + + +@compatibility(is_backward_compatible=False) +def topo_sort(nodes: NodeList) -> NodeList: + # sort nodes according to the topological order + indegree_map = dict.fromkeys(nodes, 0) + candidates: SimpleQueue[Node] = SimpleQueue() + + for node in nodes: + for n in node.all_input_nodes: + if n in indegree_map: + indegree_map[node] += 1 + if indegree_map[node] == 0: + candidates.put(node) + + sorted_nodes: NodeList = [] + while not candidates.empty(): + node = candidates.get() + sorted_nodes.append(node) + + for n in node.users: + if n in indegree_map: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + assert len(nodes) == len(sorted_nodes), ( + "topological sorted nodes doesn't have same length as input nodes" + ) + + return sorted_nodes + + +@compatibility(is_backward_compatible=False) +def validate_partition(partition: NodeList) -> bool: + # verify the partition doesn't form a dependency cycle in the original graph + # returns True for valid partition, False for invalid + + partition_set = set(partition) + + outputs: NodeList = [] + for node in partition_set: + for user_node in node.users: + if user_node not in partition_set: + # external user node, need to expose as an output + outputs.append(user_node) + + # Perform BFS on the partition outputs. + # If it reaches a node within the partition, then it found a cycle. + # This function takes the ownership of `root_nodes` and may modify it. + def bfs_find_cycle(root_nodes: NodeList) -> bool: + # Set used to exclude nodes that have already been visited. + # If a node has been visited, that node and all its children have + # been checked for cycles. + visited: NodeSet = set() + + # Start with `root_nodes` and traverse through (toward child nodes) + # their connected sub-graph. Nodes in `visited` won't be added + # to `queue` again. + queue: NodeList = root_nodes + while queue: + current = queue.pop() + visited.add(current) + if current in partition_set: + # Started from partition's `output` nodes, and reached + # another node in partition. Cycle! + return True + for user_node in current.users: + if user_node in visited: + continue + queue.append(user_node) + # `root_nodes` don't cause cycle. + return False + + # Use all output nodes as roots to traverse + # the graph to check cycles. + if bfs_find_cycle(outputs): + return False + + return True + + +@compatibility(is_backward_compatible=False) +def fuse_as_graphmodule( + gm: GraphModule, + nodes: NodeList, + module_name: str, + partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None, + *, + always_return_tuple: bool = False, +) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: + """ + Fuse nodes in graph_module into a GraphModule. + + Args: + gm (GraphModule): target graph_module + + nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted + + module_name: class name for the fused GraphModule + + partition_lookup_table (Optional[Dict[Node, None]]): optional dict of nodes to speed up lookup + + always_return_tuple (bool): whether to always return a tuple, even if there is only one output + + Returns: + fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` + + original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` + + original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` + + """ + + # assumption: nodes are already sorted in topo order + + for node in nodes: + assert node.graph.owning_module is gm, ( + f"{node} doesn't belong to passed in graph module {gm._get_name()}" + ) + assert not node._erased, f"{node} has been removed from owning graph" + assert node in gm.graph._find_nodes_lookup_table, ( + f"{node} is not found in graph module {gm._get_name()}" + ) + + # validates partition doesn't introduce dependency circles in the graph + assert validate_partition(nodes), "Invalid partition, found dependency cycles" + + # if no dict of partition nodes is provided, reconstruct it by nodes list to reduce lookup time + if partition_lookup_table is None: + partition_lookup_table = dict.fromkeys(nodes) + + subgraph = Graph() + + node_to_placeholder: dict[ + Node, Node + ] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph + + # handles inputs through graph.node_copy's arg_transform functions + def remap_inputs(x: Node) -> Node: + if x.op == "get_attr": + # TODO: do we really need copy the get_attr node into the graph? + # do something here + pass + + if x in partition_lookup_table: + # x is inside subgraph, return the copied node + # the node should have been copied already, as we are copying graph in the topological order + return node_map[x] + + if x not in node_to_placeholder: + # x is not in subgraph, create a new placeholder for subgraph + placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelevant for the placeholder node + placeholder_node.meta = copy.copy(x.meta) + node_to_placeholder[x] = placeholder_node + + return node_to_placeholder[x] + + # copy nodes in topological order + for node in nodes: + new_node = subgraph.node_copy(node, remap_inputs) + node_map[node] = new_node + + # handles outputs + output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs + + for node in nodes: + for user_node in node.users: + if user_node not in partition_lookup_table: + # external user node, need to expose as an output + output_mapping[node] = node_map[node] + + # outs contain nodes in the new subgraph + outs = tuple(output_mapping.values()) + + if always_return_tuple: + # always return a tuple, even if there is only one output + subgraph.output(outs) + else: + # If there's a single output then return it directly, otherwise return a tuple. + subgraph.output(outs[0] if len(outs) == 1 else outs) + + # lint to ensure correctness + subgraph.lint() # type: ignore[no-untyped-call] + fused_gm: GraphModule + fused_gm, _ = lift_subgraph_as_module( + gm, subgraph, comp_name="", class_name=module_name + ) + + # sub_gm's input nodes in the original module + original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys()) + + # sub_gm's outputs node in the original module + original_outputs: tuple[Node, ...] = tuple(output_mapping.keys()) + + return fused_gm, original_inputs, original_outputs + + +@compatibility(is_backward_compatible=False) +def insert_subgm( + gm: GraphModule, + sub_gm: GraphModule, + orig_inputs: tuple[Node, ...], + orig_outputs: tuple[Node, ...], +) -> GraphModule: + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) + + def last_node(target_nodes: tuple[Node, ...]) -> Node | None: + for node in reversed(gm.graph.nodes): + if node in target_nodes: + return node + return None + + last_output_node: Node | None = last_node(orig_outputs) + assert last_output_node is not None + + # Create a call_module node in main graph. + with gm.graph.inserting_after(last_output_node): + module_node = gm.graph.call_module( + submodule_name, args=orig_inputs, kwargs=None + ) + output_node = sub_gm.graph.output_node() + + next_node = module_node.next + with gm.graph.inserting_before(next_node): + if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple): + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) + return gm + + +@compatibility(is_backward_compatible=False) +def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: + # erase original nodes in inversed topological order + for node in reversed(nodes): + gm.graph.erase_node(node) + + +@compatibility(is_backward_compatible=False) +def fuse_by_partitions( + gm: GraphModule, + partitions: list[dict[Node, _Optional[int]]], + prefix: str = "fused_", + always_return_tuple: bool = False, +) -> GraphModule: + for partition_id, partition in enumerate(partitions): + sorted_nodes = topo_sort(list(partition)) + + submodule_name = prefix + str(partition_id) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, + sorted_nodes, + submodule_name, + partition, + always_return_tuple=always_return_tuple, + ) + + insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) + + erase_nodes(gm, sorted_nodes) + + stable_topological_sort(gm) + gm.graph.lint() + + return gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f253cb292860de6ec8d3f8418d4e9d5033ca9c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py @@ -0,0 +1,447 @@ +# mypy: allow-untyped-defs +import copy +import logging +import os +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Union + +import torch +from torch.fx import Graph, Node +from torch.fx._compatibility import compatibility + + +__all__ = ["SubgraphMatcher", "InternalMatch"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class InternalMatch: + # Nodes from which the match was found + anchors: list[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] = field(default_factory=dict) + + # nodes in target graph that are matched placeholder in pattern + placeholder_nodes: list[Node] = field(default_factory=list) + + # nodes in matched subgraph returned by output + returning_nodes: list[Node] = field(default_factory=list) + + # map from a string name to a node in the target graph + # only available if the matcher is `SubgraphMatcherWithNameNodesMap` + name_node_map: dict[str, Node] = field(default_factory=dict) + + def __copy__(self): + return InternalMatch( + anchors=self.anchors, + nodes_map=self.nodes_map.copy(), + placeholder_nodes=self.placeholder_nodes.copy(), + returning_nodes=self.returning_nodes.copy(), + ) + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcher: + def __init__( + self, + pattern: Graph, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + """ + Args: + pattern: the targeted matching pattern, represented in fx.Graph. + match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. + If False, output node is ignored during match. + match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of + the targeted pattern. If False, placeholder nodes will be used a wildcard. + remove_overlapping_matches: If True, in the case of overlapping matches, only the first match + will be returned. + ignore_literals: If True, will not check if literals are equal and + will instead treat them as wildcards. + """ + + self.pattern = pattern + self.match_output = match_output + self.match_placeholder = match_placeholder + self.remove_overlapping_matches = remove_overlapping_matches + self.ignore_literals = ignore_literals + + if len(pattern.nodes) == 0: + raise ValueError( + "SubgraphMatcher cannot be initialized with an empty pattern" + ) + + for node in pattern.nodes: + if node.op != "output" and not node.is_impure(): + assert len(node.users) > 0, ( + "SubgraphMatcher cannot be initialized with an pattern with dead code" + ) + + # TODO: assert pattern is a connected graph + + self.pattern_placeholder_nodes = [ + n for n in pattern.nodes if n.op == "placeholder" + ] + output_node = next(iter(reversed(pattern.nodes))) + # nodes returned by outputs + self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes + + self.pattern_anchors: list[Node] = [] + if match_output: + self.pattern_anchors = [output_node] + else: + # If a node has output_node as the ONLY user, then this node is a graph sink, + # and should be matched against as an anchor + self.pattern_anchors = [ + n for n in output_node.all_input_nodes if len(n.users) == 1 + ] + + def _match_attributes(self, pn: Node, gn: Node) -> bool: + # Attributes matching is complicated. Right now we only support matching constant tensor + assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string." + assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string." + + pn_value = torch.fx.graph_module._get_attr(pn.graph.owning_module, pn.target) + gn_value = torch.fx.graph_module._get_attr(gn.graph.owning_module, gn.target) + + if type(pn_value) is not type(gn_value): + return False + + # Don't require exact match on tensor values. + if isinstance(pn_value, torch.Tensor): + return isinstance(gn_value, torch.Tensor) + else: + raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") + return False + + def _nodes_are_equal(self, pn: Node, gn: Node, node_name_match: str = "") -> bool: + # if exact match for placeholder is not required, then use placeholder as a wildcard + if not self.match_placeholder and pn.op == "placeholder": + return True + + if node_name_match and node_name_match in gn.name: + return True + + if pn.op == gn.op: + if pn.op == "placeholder" or pn.op == "output": + return True + elif pn.op == "get_attr": + return self._match_attributes(pn, gn) + return pn.target == gn.target + return False + + def _is_contained(self, nodes_map: dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + + # Placeholders can be used by other nodes in the graphs + lookup: dict[Node, Node] = { + gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder" + } + + for gn, pn in lookup.items(): + # nodes returned by output are allowed to be used in other areas of the graph + if pn in self.pattern_returning_nodes: + continue + + for user in gn.users: + # If this node has users that were not in `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + def _remove_overlapping_matches( + self, matches: list[InternalMatch] + ) -> list[InternalMatch]: + non_overlapping_matches: list[InternalMatch] = [] + nodes_matched: set[Node] = set() + + for match in matches: + found_overlap = False + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"} and gn in nodes_matched: + found_overlap = True + break + + if not found_overlap: + non_overlapping_matches.append(match) + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"}: + nodes_matched.add(gn) + return non_overlapping_matches + + def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: + assert not (isinstance(pn, Node) and isinstance(gn, Node)), ( + "pn and gn cannot both be Node" + ) + + if isinstance(pn, Node) and not isinstance(gn, Node): + if pn.op == "placeholder": + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + match.nodes_map[pn] = gn + return True + else: + return False + elif not isinstance(pn, Node) and isinstance(gn, Node): + return False + else: + return type(gn) is type(pn) and gn == pn + + def _match_nodes( + self, pn: Node, gn: Node, match: InternalMatch, node_name_match: str = "" + ) -> bool: + logger.info(" matching %s to %s", pn, gn) + + assert isinstance(pn, Node) and isinstance(gn, Node), str( + f"pn and gn must be Node, pn: {pn}, gn: {gn}" + ) + + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + # TODO: use a more efficient way to check if gn is matched before: two-way dict + if gn in match.nodes_map.values(): + return False + + if not self._nodes_are_equal(pn, gn, node_name_match): + return False + + # Optimistically mark `pn` as a match for `gn`, and save a local copy of match + saved_match = copy.copy(match) + match.nodes_map[pn] = gn + + # Placeholder is a wildcard and can be matched with any python object + # (including list/tuple) + if pn.op == "placeholder": + return True + + # Recursively traverse upwards to check if `pn` is a true + # match for `gn` + match_found = True + + def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool: + if len(args1) != len(args2): + return False + + for a1, a2 in zip(args1, args2): + if isinstance(a1, Node) and isinstance(a2, Node): + matched = self._match_nodes(a1, a2, match) + elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)): + matched = _match_args(a1, a2) + else: + matched = ( + self._match_literals(a1, a2, match) or self.ignore_literals + ) + + if not matched: + return False + + return True + + # Flatten all args/kwargs into 1 list of args + pn_args, gn_args = None, None + if ( + ( + len(pn.args) != len(gn.args) + or list(pn.kwargs.keys()) != list(gn.kwargs.keys()) + ) + and pn.op == "call_function" + and isinstance(pn.target, torch._ops.OpOverload) + ): + args_schema = pn.target._schema.arguments + + def get_all_arguments(orig_args, orig_kwargs): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + pn_args = get_all_arguments(pn.args, pn.kwargs) + gn_args = get_all_arguments(gn.args, gn.kwargs) + + elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list( + gn.kwargs.keys() + ): + pn_args = list(pn.args) + gn_args = list(gn.args) + pn_args.extend(list(pn.kwargs.values())) + gn_args.extend(list(gn.kwargs.values())) + else: + match_found = False + + match_found = ( + match_found + and pn_args is not None + and gn_args is not None + and _match_args(pn_args, gn_args) + ) + + if not match_found: + # revert to saved_match before matching with current node + match = copy.copy(saved_match) + return False + + return True + + def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: + """ + Returns: + The matched subgraphs. + The returned subgraph would be fully self-contained, meaning the nodes (except placeholder + and nodes returned by output) can only be consumed by nodes within the matched subgraph. + + Subgraph pattern matcher is implemented with the backtracking style in the following steps: + + 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes + are the "sinks" (nodes with no user other than the output node) of the pattern graph. + One pattern graph could have multiple anchors if it has multiple return values. + + 2. In the target graph, we identify the potential candidate nodes that can be matched + with each anchor. These anchor-candidate pairs are the starting points for + pairwise per-node matching. + + 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both + pattern and target graphs. For every pattern nodes along traversal path, we compare it + against the target nodes. In case any comparison failed, the match for this anchor-candidate + pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` + for more details. + + 4. In the case of multiple anchors, every anchor will need to find a match using step 3. + In addition, the matches found between anchors need to have a common intersection node + in order for the match to be valid. This is implemented with backtracking. See `backtracking` + for more details. + + Notice: graph traversal must be done in the reverser order because a tensor can have multiple + consumers, but can only have a single producer. Only with reverser order, we can we jointly + traverse the pattern and target graph in a deterministic path. + + Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, + in practice, it's unlikely to blow up. + + """ + from torch.fx.passes.utils.fuser_utils import validate_partition + + # find candidate nodes to match with pattern anchors + match_candidates: dict[Node, list[Node]] = defaultdict(list) + for pattern_anchor in self.pattern_anchors: + for node in graph.nodes: + if self._nodes_are_equal(pattern_anchor, node, node_name_match): + match_candidates[pattern_anchor].append(node) + match_candidates_list = list(match_candidates.items()) + + logger.info("Initial match_candidates_list: %s\n", match_candidates_list) + + matches: list[InternalMatch] = [] + + def backtracking(anchor_index, match): + if anchor_index == len(match_candidates_list): + match.placeholder_nodes = [ + match.nodes_map[pn] for pn in self.pattern_placeholder_nodes + ] + match.returning_nodes = [ + match.nodes_map[pn] for pn in self.pattern_returning_nodes + ] + matches.append(match) + + logger.info("Found a match: %s\n", match) + return + + pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] + saved_match = copy.copy(match) + + for node in candidate_nodes: + logger.info("Trying to match anchor %s to %s", pattern_anchor, node) + + match_found = self._match_nodes( + pattern_anchor, node, match, node_name_match + ) + if match_found: + # match next anchor + backtracking(anchor_index + 1, match) + else: + logger.info( + "Failed to match anchor %s to %s\n", pattern_anchor, node + ) + + # revert to saved_match before matching with current anchor + match = copy.copy(saved_match) + + match = InternalMatch(anchors=self.pattern_anchors) + if match_candidates_list: + backtracking(0, match) + + # filter out the matches where the subgraph is not fully_contained + before = len(matches) + matches = [match for match in matches if self._is_contained(match.nodes_map)] + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because they are not fully contained", + before - after, + ) + + # filter out the matches that form a cycle if the subgraph is fused + valid_matches = [] + for match in matches: + matched_compute_nodes = [ + gn + for pn, gn in match.nodes_map.items() + if pn.op not in {"placeholder", "output"} + ] + if validate_partition(matched_compute_nodes): + valid_matches.append(match) + if len(valid_matches) != len(matches): + logger.info( + "Filtered out %s matches because \ + matched subgraph would form a cycle if fused", + len(matches) - len(valid_matches), + ) + + if self.remove_overlapping_matches: + before = len(valid_matches) + matches = self._remove_overlapping_matches(valid_matches) + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because matched subgraphs are overlapping", + before - after, + ) + + logger.info("Matches returned: %s", matches) + + return matches diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3114d55b635fcb5d02b8e57faade2474ec021e7f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -0,0 +1,114 @@ +from torch.fx import Graph, GraphModule, Node +from torch.fx._compatibility import compatibility + +from .matcher_utils import InternalMatch, SubgraphMatcher + + +__all__ = ["SubgraphMatcherWithNameNodeMap"] + + +def _split_to_graph_and_name_node_map( + gm: GraphModule, +) -> tuple[GraphModule, dict[str, Node]]: + from torch.fx.graph import _PyTreeInfo + from torch.utils._pytree import tree_flatten, tree_unflatten + + name_node_map = {} + for n in gm.graph.nodes: + if n.op == "output": + assert gm._out_spec is not None + output = tree_unflatten(n.args[0], gm._out_spec) + assert isinstance(output, tuple), ( + "Expecting the pattern graph to return a tuple" + ) + assert len(output) >= 2, ( + "Expecting the pattern graph to have at least two outputs" + ) + *out, name_node_map = output + flattened, out_spec = tree_flatten(out) + assert isinstance(name_node_map, dict), ( + "Expecting the input graph to have a dict output as the last element" + ) + n.args = (flattened,) + orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] + gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] + orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec + ) + gm.recompile() + return gm, name_node_map + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): + """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name, + this requires pattern to have specific format (returning and additional dictionary at the output, + that has node name as key, and the node in the pattern graph as value, see Example for more details) + + Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during + initialization since we need to modify the graph (which requires `recompile` the GraphModule) + + Example:: + def pattern(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + return relu, {"conv": conv, "relu": relu} + + + def target_graph(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + relu *= 2 + return relu + + + pattern_gm = export_for_training(pattern, example_inputs).module() + target_gm = export_for_training(target_graph, example_inputs).module() + matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) + matches = matcher.match(target_gm) + for match in matches: + match.name_node_map["conv"].meta["annotation"] = ... + + """ + + def __init__( + self, + pattern_gm: GraphModule, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) + self.name_node_map = name_node_map + super().__init__( + pattern_gm.graph, + match_output, + match_placeholder, + remove_overlapping_matches, + ignore_literals, + ) + + def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: + """The returned InternalMatch will have name_node_map populated with a map + from node name (str) to the target node, e.g. + {"conv": target_conv_ndoe, "relu": target_relu_node} + + this requires the pattern graph returns an additional + output of node name to node, e.g. instead of: + ``` + def pattern(...): + ... + return relu + ``` + we should do: + ``` + def pattern(...): + ... + return relu, {"conv": conv, "relu": relu} + ``` instead + """ + internal_matches = super().match(graph, node_name_match) + for internal_match in internal_matches: + for k, n in self.name_node_map.items(): + internal_match.name_node_map[k] = internal_match.nodes_map[n] + return internal_matches diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82259b8a36ab78ce67ab14411ca4522cc33cd83c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py @@ -0,0 +1,163 @@ +import logging +import os +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.node import Node + + +__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class SourcePartition: + # Nodes in a particular partition + nodes: list[Node] + + # The source these nodes decomposed from + source: Any + + # Nodes in the graph that are needed as inputs to the partition + # These do not include the params of the partition + input_nodes: list[Node] = field(default_factory=list) + + # Nodes in the partition that are being used by nodes outside of the + # partition + output_nodes: list[Node] = field(default_factory=list) + + # Parameters that are being used + params: list[Node] = field(default_factory=list) + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def get_source_partitions( + graph: Graph, + wanted_sources: list[Any], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> dict[Any, list[SourcePartition]]: + """ + Args: + graph: The graph we want to partition + wanted_sources: List of sources of nodes that were decomposed from this + source. This can be a function (ex. torch.nn.functional.linear) or a + leaf module type (ex. torch.nn.Linear). + + Returns: + Dictionary mapping sources that were given to a list of SourcePartitions + that correspond to the list of nodes that were decomposed from the given + source. + """ + modules: dict[type, dict[str, list[Node]]] = {} + + for node in graph.nodes: + # The metadata source_fn should contain a tuple of a unique name for the + # source, and the source function if the node is decomposed from a + # function, or the type of module if the node is decomposed from a leaf + # module + + # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can + # be different from "source_fn_stack", for example for the add_ node + # decomposed from batch norm. We should remove the check on "source_fn_stack" + # after we fix "torch_fn". T199561090 + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( + torch_fn := node.meta.get("torch_fn", None) + ) is not None: + node_fqn, source_fn = torch_fn + source_fn_name = source_fn.split(".")[1] + if source_fn_name in wanted_sources: + diff_modules = modules.setdefault(source_fn_name, {}) + partition = diff_modules.setdefault(node_fqn, []) + partition.append(node) + + if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None: + source_fn = source_fn_st[-1] + if source_fn[1] in wanted_sources: + diff_modules = modules.setdefault(source_fn[1], {}) + partition = diff_modules.setdefault(source_fn[0], []) + partition.append(node) + + def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: + input_nodes = set() + output_nodes = set() + params = set() + for node in nodes: + for arg in node.args: + if isinstance(arg, Node) and arg not in nodes and arg.op != "get_attr": + input_nodes.add(arg) + + if node.op == "get_attr": + params.add(node) + # get_attr nodes won't be output nodes + continue + + for user in node.users: + if user not in nodes: + output_nodes.add(node) + + return SourcePartition( + nodes, + module_type, + list(input_nodes), + list(output_nodes), + list(params), # type: ignore[arg-type] + ) + + ret: dict[type[Any], list[SourcePartition]] = {} + + if filter_fn: + # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the + # filter condition + filtered_modules = {} + for tp, name_to_partition in modules.items(): + filtered_name_to_partition = { + name: partition + for name, partition in name_to_partition.items() + if all(map(filter_fn, partition)) + } + filtered_modules[tp] = filtered_name_to_partition + modules = filtered_modules + + for k, v in modules.items(): + ret[k] = [make_partition(partition, k) for partition in v.values()] + + return ret + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def check_subgraphs_connected( + subgraph1: SourcePartition, subgraph2: SourcePartition +) -> bool: + """ + Given two subgraphs A and B (in the form of a list of nodes), checks if + A has nodes connecting to at least one node in B -- aka there exists a node + in B that uses a node in A (not the other way around). + """ + + for node in reversed(subgraph1.nodes): + for user in node.users: + if user in subgraph2.nodes: + return True + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..4ca11a5566b860dbfe7866db7c10b532e9ab0537 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Activation.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..1002f53b0793f655db825b5203e500a93af33615 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,49 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. +#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native::binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace at::native::binary_internal + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..37684e127bed6702510109919f38de7a5528f5d3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh @@ -0,0 +1,1136 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// This file provides two functions to help write GPU elementwise kernels: +// +// gpu_kernel(TensorIterator iter, ) +// gpu_kernel_with_scalars(TensorIterator iter, ) +// +// The gpu_kernel_with_scalars generates specializations that support a +// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar +// is lifted to a kernel parameter instead of copying to device memory. +// This should be used in conjunction with TensorIterator::allow_cpu_scalars_, +// which is the default for TensorIterator::binary_op. Otherwise, all inputs +// and the output must be on the GPU. +// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __NVCC__ +#define ASSERT_HOST_DEVICE_LAMBDA(type) \ + static_assert( \ + __nv_is_extended_host_device_lambda_closure_type(type), \ + #type " must be a __host__ __device__ lambda") +#else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +#endif + +namespace at::native { + +#ifdef USE_ROCM +// Custom configuration for vectorized elementwise kernel +// with template instantiation. +namespace vectorized_templated_config { +constexpr int num_threads() { + return 512; +} + +constexpr int elems_per_thread() { + return 32; +} + +constexpr int block_work_size() { + return elems_per_thread() * num_threads(); +} +} // namespace vectorized_templated_config +#endif + +template +constexpr auto sum_of_sizes(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return (sizeof(std::tuple_element_t) + ...); + } +} + +#ifdef USE_ROCM +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else if constexpr (io_sizes < 4) { + return 8; + } else { + return 4; + } +} +#else +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else { + return 8; + } +} +#endif + + +//thread work size of 8 regresses the perf of elementwise kernel on cuda +//this doesn't change ROCm behavior as thread_work_size is already 4 on ROCm +constexpr int elementwise_thread_work_size() {return 4;} +constexpr int elementwise_block_work_size() { + return elementwise_thread_work_size() * num_threads(); +} + +template +constexpr auto io_block_work_size() { + return num_threads() * elems_per_thread(); +} + +#ifdef USE_ROCM +template +constexpr auto input_size(args_t args, std::index_sequence) { + if constexpr (sizeof...(Is) == 0) { + return 0; + } else { + return sizeof(std::tuple_element_t<0, args_t>); + } +} + +template +constexpr auto calc_optimal_vec_size() { + static_assert(vec_size != 0); + static_assert(io_size != 0); + if constexpr (io_size == 1 && vec_size >= 16) { + return 16; + } else if constexpr (io_size <= 2 && vec_size >= 8) { + return 8; + } else if constexpr (io_size <= 4 && vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 4) { + return 4; + } else if constexpr (vec_size >= 2) { + return 2; + } else { + return 1; + } +} +#endif + +template +constexpr auto calc_io_size(){ + using traits = function_traits; + using args_t = typename traits::ArgsTuple; +#ifdef USE_ROCM + constexpr auto input_size = at::native::input_size(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return (input_size > 0) ? ((input_size < output_size) ? input_size : output_size) : output_size; +#else + constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence>{}); + constexpr auto output_size = sizeof(typename traits::result_type); + return input_size + output_size; +#endif +} + +#ifndef USE_ROCM +// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel +// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be +// used on sm_90 and sm_100 exclusively. +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + if constexpr (vec_size == 8) { +#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } +#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000 + } else { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + int remaining = N - io_block_work_size() * blockIdx.x; + + if (remaining < io_block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + elems_per_thread()>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized()>(data)); + } + } +} + +#else // USE_ROCM +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + constexpr auto io_size = calc_io_size(); +#if defined(USE_ROCM) && defined(__gfx942__) + // Similar check in launch_vectorized_kernel() as well. Both should be in sync. + constexpr int tws = 16; +#else + constexpr int tws = elems_per_thread(); +#endif + constexpr int bws = tws * num_threads(); + int remaining = N - bws * blockIdx.x; + + if (remaining < bws) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast, + tws>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + constexpr auto optimal_vec_size = calc_optimal_vec_size(); + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} +#endif // USE_ROCM + +template < + typename func_t, + typename array_t, + int elems_per_thread, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - elems_per_thread * num_threads() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + constexpr auto io_size = calc_io_size(); + auto stream = at::cuda::getCurrentCUDAStream(); +#ifdef USE_ROCM + int vec_size = memory::can_vectorize_up_to(data); + c10::DeviceIndex curDevice = -1; + AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); + // Similar check in vectorized_elementwise_kernel() as well. Both should be in sync. + int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? 16 : elems_per_thread(); +#else + using cpp_type = typename function_traits::result_type; + const uint16_t max_vec_size = memory::can_vectorize_up_to(data); + uint16_t vec_size = 16 / static_cast(sizeof(cpp_type)); + vec_size = std::min(vec_size, max_vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. + cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index()); + const int computeCapability = p->major * 10 + p->minor; + if (computeCapability != 90 && computeCapability != 100) { + vec_size = std::min(vec_size, 4); + } + if constexpr (sizeof(cpp_type) < 2) { + vec_size = std::min(vec_size, 4); + } + int tws = elems_per_thread(); +#endif + int bws = tws * num_threads(); + int64_t grid = (N + bws - 1) / bws; + switch (vec_size) { +#ifdef USE_ROCM + case 16: + vectorized_elementwise_kernel<16, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; +#endif + case 8: + vectorized_elementwise_kernel<8, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +#ifdef USE_ROCM +template < + int vec_size, + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... InputTypes> +C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads()) +__global__ void vectorized_templated_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t inp_calc, + out_calc_t out_calc, + loader_t loader, + storer_t storer) { + int remaining = N - + vectorized_templated_config::block_work_size() * + (gridDim.x - blockIdx.x - 1); + constexpr bool reverted_idx = true; + + if (remaining < + vectorized_templated_config::block_work_size()) { // if this block handles + // the reminder, + // just do a naive unrolled loop + auto policy = memory::policies::unroll_base< + vectorized_templated_config::num_threads(), + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + vectorized_templated_config::elems_per_thread()>( + data, remaining, inp_calc, out_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + auto policy = memory::policies::vectorized_templated< + vec_size, + array_t, + vectorized_templated_config::elems_per_thread(), + vectorized_templated_config::num_threads(), + OutputType, + InputTypes...>(data); + elementwise_kernel_helper(f, policy); + } +} + +// This function assume trivial 1d and supports template specialization +// to avoid dynamic casting. +// Input vectorization size is based on runtime information, i.e. +// the actual data types of the input and output tensor and cannot +// be determined using the functor type, as in regular non-templated +// vectorized kernels. The caller is in charge of selecting the correct input +// vectorization length. +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t, + typename OutputType, + typename... InputTypes> +static inline void launch_vectorized_templated_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) / + vectorized_templated_config::block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + int vec_size = memory::can_vectorize_up_to(data); + switch (vec_size) { + case 8: + vectorized_templated_elementwise_kernel< + 8, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 4: + vectorized_templated_elementwise_kernel< + 4, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_templated_elementwise_kernel< + 2, + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + OutputType, + InputTypes...> + <<>>( + N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: + // vector size 1 is not handled as part of vectorize_templated kernel + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} +#endif + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + + int64_t grid = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#ifdef USE_ROCM +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel_manual_unroll(int N, func_t f) { + int tid = threadIdx.x; + constexpr int nv = nt * vt; + int idx = nv * blockIdx.x + tid; + if ((idx + nt*(vt-1)) < N) { + f(idx, true); + } else { +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx, false); + idx += nt; + } + } + } +} + +template +static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel_manual_unroll<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} +#endif + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto offset_calc = ::make_offset_calculator(iter); +#ifndef USE_ROCM + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + }); +#else + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + if constexpr (unroll_factor == 4) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + } else { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx+grp_sz); + auto offsets2 = offset_calc.get(idx+grp_sz*2); + auto offsets3 = offset_calc.get(idx+grp_sz*3); + auto offsets4 = offset_calc.get(idx+grp_sz*4); + auto offsets5 = offset_calc.get(idx+grp_sz*5); + auto offsets6 = offset_calc.get(idx+grp_sz*6); + auto offsets7 = offset_calc.get(idx+grp_sz*7); + arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]); + arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]); + arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]); + arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]); + arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]); + arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]); + arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]); + arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]); + auto tmp0 = invoke(f, &data[1], &offsets0[1], 1); + auto tmp1 = invoke(f, &data[1], &offsets1[1], 1); + auto tmp2 = invoke(f, &data[1], &offsets2[1], 1); + auto tmp3 = invoke(f, &data[1], &offsets3[1], 1); + auto tmp4 = invoke(f, &data[1], &offsets4[1], 1); + auto tmp5 = invoke(f, &data[1], &offsets5[1], 1); + auto tmp6 = invoke(f, &data[1], &offsets6[1], 1); + auto tmp7 = invoke(f, &data[1], &offsets7[1], 1); + *out0 = tmp0; + *out1 = tmp1; + *out2 = tmp2; + *out3 = tmp3; + *out4 = tmp4; + *out5 = tmp5; + *out6 = tmp6; + *out7 = tmp7; + } + } else { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data[1], &offsets[1], 1); + } + }); +#endif +} + +#ifdef USE_ROCM +namespace { +template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity, + size_t arg_num = 0> +struct check_binary_functor_types_for_specialization { + constexpr static inline bool check() { + if constexpr (arity != 2) + return false; + if constexpr (arg_num == 0) { + using SelectedType = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } else if constexpr (arg_num == 1) { + using SelectedType2 = std::tuple_element_t; + if constexpr (std::is_same_v) + return check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arg_num + 1>::check(); + } + return false; + } +}; + +// Bottom case: if we got this far, assume correct type matching except +// when there are no arguments (arity == 0). +template < + typename TupleLike, + typename FirstParamTy, + typename SecondParamTy, + size_t arity> +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + arity, + arity> { + constexpr static inline bool check() { + if constexpr (arity != 0) + return true; + return false; + } +}; + +template +struct check_binary_functor_types_for_specialization< + TupleLike, + FirstParamTy, + SecondParamTy, + 0, + 0> { + constexpr static inline bool check() { + return false; + } +}; + +// The following is a list of type specializations for vectorized_templated +// elementwise kernel. The three types refer to runtime types of the output +// tensor, first tensor argument, and the second tensor argument used for a +// binary functor. +constexpr std::array rt_binary_specializations = { + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value}), + std::array( + {c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value})}; + +bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) { + if (iter.ninputs() != 2) + return false; + for (auto spec : rt_binary_specializations) + if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] && + iter.input_dtype(1) == spec[2]) + return true; + return false; +} + +template +struct type_specialized_kernel_launcher { + template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> + static void apply( + ScalarType ret_t, + ScalarType arg0_t, + ScalarType arg1_t, + int64_t numel, + func_t f, + array_t data, + inp_calc_t input_offset_calculator, + out_calc_t output_offset_calculator, + loader_t loader, + storer_t storer) { + constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0]; + constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1]; + constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2]; + if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) { + using cret_t = c10::impl::ScalarTypeToCPPTypeT; + using carg0_t = c10::impl::ScalarTypeToCPPTypeT; + using carg1_t = c10::impl::ScalarTypeToCPPTypeT; + launch_vectorized_templated_kernel< + func_t, + array_t, + inp_calc_t, + out_calc_t, + loader_t, + storer_t, + cret_t, + carg0_t, + carg1_t>( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + } + } +}; + +template +struct type_specialized_broadcast_kernel_launcher { + template < + typename func_t, + typename array_t, + typename dtypes_t, + typename calc_t> + static void apply( + int64_t numel, + func_t f, + array_t data, + dtypes_t dtypes, + calc_t offset_calc) { + using traits = function_traits; + using ret_t = typename traits::result_type; + using arg0_t = typename traits::template arg<0>::type; + using arg1_t = typename traits::template arg<1>::type; + if (dtypes[0] == rt_binary_specializations[arg_index][0] && + dtypes[1] == rt_binary_specializations[arg_index][1] && + dtypes[2] == rt_binary_specializations[arg_index][2]) { + using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + auto u = c10::load(data[1] + offsets0[1]); + auto v = c10::load(data[2] + offsets0[2]); + ret_t result0 = f(c10::convert(u), c10::convert(v)); + auto u1 = c10::load(data[1] + offsets1[1]); + auto v1 = c10::load(data[2]+ offsets1[2]); + ret_t result1 = f(c10::convert(u1), c10::convert(v1)); + auto u2 = c10::load(data[1] + offsets2[1]); + auto v2 = c10::load(data[2] + offsets2[2]); + ret_t result2 = f(c10::convert(u2), c10::convert(v2)); + auto u3 = c10::load(data[1] + offsets3[1]); + auto v3 = c10::load(data[2] + offsets3[2]); + ret_t result3 = f(c10::convert(u3), c10::convert(v3)); + *(ret_cpp_t*)out0 = c10::convert(result0); + *(ret_cpp_t*)out1 = c10::convert(result1); + *(ret_cpp_t*)out2 = c10::convert(result2); + *(ret_cpp_t*)out3 = c10::convert(result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + auto u = c10::load(data[1] + offsets[1]); + auto v = c10::load(data[2] + offsets[2]); + ret_t result = f(c10::convert(u), c10::convert(v)); + *(ret_cpp_t*)out = c10::convert(result); + } + }); + } + } +}; + +} // namespace +#endif + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { +#ifdef USE_ROCM + // Attempt to call specialized vectorized elementwise kernel + // that enables interleaving. + if (check_binary_rt_types_for_specialization(iter) && + memory::can_vectorize_up_to(data) > 1) { + // constexpr to reduce the amount of kernels generated for + // vectorized templated elementwise and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + // If we got here, we know we are in one of the specialized cases. We + // need to translate the runtime type to a statically known type. This + // is effectively hoisting to the host the switch over runtime type in + // the kernel in fetch_and_cast. Loader, storer, offset calculators are + // only needed for the reminder loop. + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + memory::detail::static_unroll< + type_specialized_kernel_launcher, + rt_binary_specializations.size()>:: + with_args( + iter.dtype(0), + iter.input_dtype(0), + iter.input_dtype(1), + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); + return; + } + } + std::array dtypes; + auto inner_strides = iter.get_inner_strides(); + std::array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + void* out0 = data[0] + strides[0] * idx; + void* out1 = data[0] + strides[0] * (idx + grp_sz); + void* out2 = data[0] + strides[0] * (idx + grp_sz * 2); + void* out3 = data[0] + strides[0] * (idx + grp_sz * 3); + arg0_t result0 = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + arg0_t result1 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz)); + arg0_t result2 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 2)); + arg0_t result3 = invoke(f, &data[1], &strides[1], &dtypes[1], (idx + grp_sz * 3)); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_unrolled_kernel( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); +#endif + } else { + std::array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); +#ifdef USE_ROCM + if (check_binary_rt_types_for_specialization(iter)) { + // constexpr to reduce the amount of kernels generated for + // broadcast elementwise with mexed dtypes and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + memory::detail::static_unroll< + type_specialized_broadcast_kernel_launcher, + rt_binary_specializations.size()>::with_args( + numel, + f, + data, + dtypes, + offset_calc + ); + return; + } + } + + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1); + arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1); + arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1); + arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out, result); + }); +#endif + } +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh new file mode 100644 index 0000000000000000000000000000000000000000..302ed64991ab18b46d7e7bfe35cbd8482fc1c941 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::native { + +// std:: does not have clamp functors +template +struct minimum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a < b) ? a : b; + } +}; + +template +struct maximum { + __device__ T operator()(const T& a, const T& b) const { + return (_isnan(a) || a > b) ? a : b; + } +}; + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..3730e552317e77d4d9a110c6d330fa31fb2d4b47 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/GridSampler.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_grid_sampler_2d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_forward_kernel( + const TensorBase &output, const TensorBase &input, const TensorBase &grid, + int64_t interpolation_mode, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_2d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +void launch_grid_sampler_3d_backward_kernel( + const TensorBase &grad_input, const TensorBase &grad_grid, + const TensorBase &grad_output, const TensorBase &input, + const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode, + bool align_corners, std::array output_mask); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a856720d0a9d478b9e41dc8b065d7e03843a6411 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Loops.cuh @@ -0,0 +1,338 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + + + +namespace at::native { + +template +static OffsetCalculator make_input_offset_calculator(const TensorIteratorBase& iter) { + // array size can not be 0, this happens when N == 0 + constexpr int array_size = std::max(N, 1); + TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs()); + std::array strides; + int64_t element_sizes[array_size]; + for (int i = 0; i < N; i++) { + strides[i] = iter.strides(i + iter.noutputs()).data(); + element_sizes[i] = iter.element_size(i + iter.noutputs()); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +static OffsetCalculator make_output_offset_calculator(const TensorIteratorBase& iter) { + TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs()); + std::array strides; + int64_t element_sizes[num_outputs]; + for (int i = 0; i < num_outputs; i++) { + strides[i] = iter.strides(i).data(); + element_sizes[i] = iter.element_size(i); + } + return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data(), element_sizes); +} + +template +__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { + using traits = function_traits; + using return_t = typename traits::result_type; + using args_t = typename traits::ArgsTuple; + constexpr int elems_per_thread = policy_t::tws; + + int idx = blockIdx.x; + if constexpr (reverted_idx) + idx = gridDim.x - blockIdx.x - 1; + + return_t results[elems_per_thread]; + args_t args[elems_per_thread]; + + // load + policy.load(args, idx); + + // compute + #pragma unroll + for (int i = 0; i < elems_per_thread; i++) { + if (policy.check_inbounds(i)) { +#if defined(__HIP__) + results[i] = c10::guts::apply(f, args[i]); +#else + results[i] = std::apply(f, args[i]); +#endif + } + } + + // store + policy.store(results, idx); +} + +} // namespace at::native + +#include + +namespace at:: native { + +template +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_nocast(sub_iter, f); + } + return; + } + + gpu_kernel_impl_nocast(iter, f); +} + +template +void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT( + iter.device(arg).is_cuda(), + "argument ", arg, ": expected a CUDA device but found ", iter.device(arg)); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel(sub_iter, f); + } + return; + } + + gpu_kernel_impl(iter, f); +} + +template +struct AUnaryFunctor { + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + __device__ return_t operator()(arg2_t b) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} + private: + func_t f; + opmath_arg1_t a; +}; + +template +struct BUnaryFunctor { + using traits = function_traits; + using opmath_arg2_t = typename traits::template arg<1>::type; + __device__ return_t operator()(arg1_t a) const { + return f(a, b); + } + // NB: scalar is stored in higher precision! + BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} + private: + func_t f; + opmath_arg2_t b; +}; + +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + + if (iter.is_cpu_scalar(1)) { + AUnaryFunctor af(f, iter.scalar_value(1)); + iter.remove_operand(1); + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + const OptionalDeviceGuard device_guard(iter.device(1)); + gpu_kernel(iter, af); + } else if (iter.is_cpu_scalar(2)) { + BUnaryFunctor bf(f, iter.scalar_value(2)); + iter.remove_operand(2); + gpu_kernel(iter, bf); + } else { + gpu_kernel(iter, BinaryFunctor(f)); + } +} + +template +void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + // Use symmetric property of the functor to reduce number of kernels, + // requires f(a, b) == f(b, a) + TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); + + using traits = function_traits; + using opmath_arg_t = typename traits::template arg<0>::type; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + static_assert(std::is_same_v::type>, + "f is not symmetric"); + + OptionalDeviceGuard device_guard; + opmath_arg_t scalar_val{}; + + if (iter.is_cpu_scalar(1)) { + scalar_val = iter.scalar_value(1); + iter.remove_operand(1); + + // TODO: When all kernels that use gpu_kernel_with_scalars are + // ported to structured, this device guard can be deleted. This + // works around incorrect device guard generation for pre-structured + // kernels device guards, but structured kernels do it right and + // we can assume the device is already set correctly + device_guard.reset_device(iter.device(1)); + } else if (iter.is_cpu_scalar(2)) { + scalar_val = iter.scalar_value(2); + iter.remove_operand(2); + } + + if (iter.ninputs() == 2) { + gpu_kernel(iter, BinaryFunctor(f)); + } else { + AUnaryFunctor unary_f(f, scalar_val); + gpu_kernel(iter, unary_f); + } +} + +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + +namespace { // functions for `gpu_kernel_multiple_outputs`. + +// check the return type is `thrust::tuple`, not `std::tuple`. +template struct is_tuple: std::false_type {}; + +template struct is_tuple>: std::true_type {}; + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) { + int remaining = N - block_work_size() * blockIdx.x; + elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll(data, remaining, ic, oc)); +} + +template +static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel_for_multi_outputs<<>>(N, f, data, ic, oc); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using output_t = typename traits::result_type; + static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); + constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_inputs = traits::arity; + constexpr int ntensors = num_outputs + num_inputs; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors); + + std::array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + if (iter.is_contiguous()) { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator(); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } else { + auto input_calc = make_input_offset_calculator(iter); + auto output_calc = make_output_offset_calculator(iter); + launch_unrolled_kernel_for_multi_outputs(numel, f, data, input_calc, output_calc); + } +} +} // namespace + +template +void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) { + ASSERT_HOST_DEVICE_LAMBDA(func_t); + + for (int arg = 0; arg < iter.ntensors(); arg++) { + TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda()); + } + + if (iter.numel() == 0) { + return; + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + gpu_kernel_multiple_outputs(sub_iter, f); + } + return; + } + + gpu_kernel_multiple_outputs_impl(iter, f); +} + +} //namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8ee3cd13337f9abb3ddf2d8a516ebaff9e458b47 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Pow.cuh @@ -0,0 +1,63 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at::native { + +namespace { + + +// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. +// So we need to define the functions with the explicit function signatures. +// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#if defined(_MSC_VER) || defined(_LIBCPP_VERSION) +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if_t && (std::is_same_v || std::is_same_v), Base_type> + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if_t && !std::is_same_v, Base_type> + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8ac62c82fd0f8d1277f437311483cd0b64d0ad09 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Randperm.cuh @@ -0,0 +1,63 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include + +#include +#include +#include + +namespace { + +// See note [Algorithm of randperm] +template +__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + + // find the beginning of islands + if (tid >= n - 1) return; // out of range + if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island + if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island + + // find the size of islands + int island_size = 0; + do { island_size++; } + while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask)); + + // do random permutation inside each island. + data += tid; + const auto [seed, offset] = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(seed, tid, offset, &state); + for (int i = island_size - 1; i > 0; i--) { + unsigned int r = curand(&state) % (i + 1); + if (i != r) { + scalar_t tmp = data[i]; + data[i] = data[r]; + data[r] = tmp; + } + } +} + +// See note [Algorithm of randperm] +template +void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional &gen_) { + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + int64_t counter_offset = n; + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + T mask = static_cast((1UL << bits) - 1); + randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( + keys, data, mask, n, rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..32697c85aa281e4aa1a4b305ef0a5551dacab1e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/Sorting.h @@ -0,0 +1,22 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1af2600814ffcafe11e24734b2c6a3d4b107aeec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh @@ -0,0 +1,432 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. + // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. + // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + static_assert(sizeof(short) == 2, ""); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType x = __half_as_ushort(v); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; +#else + CUDA_KERNEL_ASSERT(false); + return 0u; +#endif + } + + static inline __device__ at::Half deconvert(RadixType v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return __ushort_as_half(v ^ mask); +#else + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +#endif + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::BFloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::BFloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + at::BFloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. +template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + const scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. +#if !defined(USE_ROCM) + // Must be called outside of loop to ensure all threads participate + unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize); +#endif + for (index_t i = threadIdx.x; i < sliceSize;) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = at::cuda::Bitfield::getBitfield( + val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(USE_ROCM) + counts[j] += __popcll(WARP_BALLOT(vote)); +#else + counts[j] += __popc(WARP_BALLOT(vote, mask)); +#endif + } + i += blockDim.x; +#if !defined(USE_ROCM) + mask = WARP_BALLOT(i < sliceSize, mask); +#endif + } + + // Now, for each warp, sum values + if (at::cuda::getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + gpuAtomicAddNoReturn(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + const scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + round_up(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (found != static_cast(0)) { + // all threads return this value + return val; + } + } + + // should not get here + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + const scalar_t* data, + index_t k, + bool largest, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. + bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. */ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (largest) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c5be7518ea83713d9cee9699ea39ad9896e81d22 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh @@ -0,0 +1,436 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +// Used for a segmented reduction +struct ModeUnsignedBoolPair { + unsigned int val; + bool flag; +}; + +// In the kernel below, we have a common pattern of reducing (unsigned int, +// unsigned int) pairs of data +struct ModeUnsignedPair { + unsigned int val; + unsigned int index; +}; + +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to +// the nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * blockDim.x * sizeof(T). +// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template +__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); + } + __syncthreads(); + } + + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); + } + __syncthreads(); + } +} + +// Block-wide reduction where each thread locally reduces N +// values before letting a single warp take over - assumes +// threadVals is in registers, not shared memory +// +// If smem is not used again, there is no need to __syncthreads before this +// call. However, if smem will be used, e.g., this function is called in a loop, +// then __syncthreads is needed either before or afterwards to prevent non-0 +// threads overriding smem in the next loop before num-0 thread reads from it. +template +__device__ T reduceBlockWithNThreadLocalReductions( + T* smem, + T threadVals[N], + const unsigned int numVals, + ReduceOp reduceOp, + T init) { + int offset = threadIdx.x * N; + T local = offset < numVals ? threadVals[0] : init; + +#pragma unroll + for (int i = 1; i < N; ++i) { + ++offset; + T next = offset < numVals ? threadVals[i] : init; + local = reduceOp.combine(local, next); + } + + return cuda_utils::BlockReduce(local, reduceOp, init, smem); +} + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap( + K& kA, + V& vA, + bool& validA, + K& kB, + V& vB, + bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSwapKeys( + K& kA, + bool& validA, + K& kB, + bool& validB, + bool dir, + const Comparator& comp) { + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(validA, validB); + } +} + +template < + typename K, + typename IndexType, + int Power2SortSize, + typename Comparator> +__device__ inline void bitonicSortKeys( + K keys[Power2SortSize], + bool valid[Power2SortSize], + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + flag, + comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwapKeys( + keys[pos], + valid[pos], + keys[pos + stride], + valid[pos + stride], + false, + comp); + } + + __syncthreads(); +} + +// The mode kernel has the following characteristics: It uses internal shared +// memory buffers of Power2Size, which must be greater than the number of +// elements. Additionally, there is one block for every slice to calculate the +// mode for, and in each block there is one thread for every two elements. +// +// Both sorted and positions are assumed to be contiguous Tensors with the mode +// dimension as the innermost dim, such that we can get the particular slice for +// a Tensor via its linear block dimension * the slice size. +template +__launch_bounds__(1024, 1) +__global__ void compute_mode( + const T* input, + at::cuda::detail::TensorInfo values, + at::cuda::detail::TensorInfo indices, + int64_t sliceSize, + int64_t slices) { + int tidx = threadIdx.x; + int stidx = blockDim.x + threadIdx.x; // Second index this thread responsible for + + // First, we need to calculate the offset into the sorted Tensor that + // represents the start of the slice for this block to calculate the mode for. + // This offset is a combination of the gridIndices, and the number of elements + // in the slice. + unsigned int blockId = getLinearBlockId(); + unsigned int linearOffset = blockId * sliceSize; + + if (blockId >= slices) { + return; + } + + // shmem is a dynamically sized buffer we will use throughout the kernel to + // handle computation efficiently. The size of this shmem must be + // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size) + // + // Initially, the buffer will be organized as follows: + // + // [smem (slice elements) | bmem (valid indices) | ] + extern __shared__ char shmem[]; + + // smem represents a proportion of the shared memory buffer that is used to + // store the elements from the slice: + T* smem = reinterpret_cast(shmem); + + // Each thread loads up to two elements from the Tensor into shared memory + if (tidx < sliceSize) { + smem[tidx] = c10::load(&input[linearOffset + tidx]); + } + if (stidx < sliceSize) { + smem[stidx] = c10::load(&input[linearOffset + stidx]); + } + + // Next, we initialize a boolean region of the buffer, offset by the loaded + // element smem region + bool* bmem = reinterpret_cast(&smem[Power2Size]); + + // The first use of this region stores bmem[i] = i < sliceSize to mark the + // valid components in the smem buffer + bmem[tidx] = tidx < sliceSize; + bmem[stidx] = stidx < sliceSize; + __syncthreads(); // barrier for smem, bmem initialization + + // First, sort the input slice in ascending order. smem contains the input + // elements, and bmem marks the valid indices + bitonicSortKeys( + smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) { + return a < b; + }); + __syncthreads(); // make no assumptions that the sort syncs at end + + // The next step of our algorithm is performing a block-wide comparison of + // neighboring elements. In particular, given an sorted input slice A, we + // produce an output slice B, such that B[i] = 1 if A[i-i] != A[i], otherwise + // 0. + // + // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8] + // B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // + // In particular, we can think of B[i] true indicating the start of a sequence + // of equal values in the sorted list. Similarly, we will also store the + // negation of B, which we'll call C. In particular, we can think of C[i] = + // true iff A[i-1] == A[i] in our original sorted slice. + // + // C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + + // We overwrite bmem, and treat the rest of shared memory as a buffer of + // (index, flag) pairs where the index represents values from C, and the flag + // represents values from B. + // + // [smem (sorted slice) | ubpmem (index, flag pairs)] + + struct ModeUnsignedBoolPair* ubpmem = + reinterpret_cast(&smem[Power2Size]); + + if (tidx == 0) { + ubpmem[0].flag = true; + ubpmem[0].val = 0; + } + + // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ... + ubpmem[tidx * 2 + 1].flag = + smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc. + ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag; + + // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ... + if (((tidx + 1) * 2) < Power2Size) { + ubpmem[(tidx + 1) * 2].flag = + smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2]; + ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag; + } + __syncthreads(); // barrier for ubpmem initialization + + // Next, we perform a segmented prefix sum on the neighboring elements, where + // the presence of a one indicates the start of a segment. In this case B acts + // as the segment start flags, and C is the buffer to be summed: + // + // Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0] + // Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1] + // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0] + // + // Afterwards, the (index) components of the ubpmem buffer contain the lengths + // of the segments (minus 1), i.e. the counts of each element in the original + // input. + inclusivePrefixScan( + ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) { + ModeUnsignedBoolPair c; + c.val = a.flag ? a.val : a.val + b.val; + c.flag = a.flag | b.flag; + return c; + }); + // assumes scan syncs at the end + + // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e. + // we treat the boolean flag regions as integers). We initialize these to + // represent indices, and we'll call this buffer I + struct ModeUnsignedPair* uupmem = + reinterpret_cast(ubpmem); + + // At this point, we need to find the maximum element in lengths buffer C. + // This element will represent the count (-1) of the mode. Because of the + // way we have set up the problem, the index where this mode occurs will + // also be the location of the mode value in the sorted array, e.g. + // + // smem = [0, 0, 1, 1, 1, 2] + // C = [0, 1, 0, 1, 2, 0] + // I = [0, 1, 2, 3, 4, 5] + // ^ + // maximum value, also aligned with mode = 1 + // + // We perform a block wide max-reduction of the C buffer, but we also need the + // indices to come along with it, so we utilize the uupmem construction. + // + // At the end we need to return the ModeUnsignedPair containing index = 4, val + // = 2, which represents the max + + // In practice, we will make each thread locally reduce 2 values in its + // registers prior to the global block-wide reduction. Note that instead of + // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with + // adjacent elements. This is because the reduce code below relies on thread + // elements to be adjacent. + struct ModeUnsignedPair uup[2]; + uup[0].index = tidx * 2; + uup[0].val = ubpmem[tidx * 2].val; + uup[1].index = tidx * 2 + 1; + uup[1].val = ubpmem[tidx * 2 + 1].val; + __syncthreads(); + + struct ModeUnsignedPair max = {0, 0}; + + struct MaxOp { + inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const { + return b.val > a.val ? b : a; + } + + inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const { + ModeUnsignedPair ret; + ret.index = WARP_SHFL_DOWN(acc.index, offset); + ret.val = WARP_SHFL_DOWN(acc.val, offset); + return ret; + } + } max_op; + + max = reduceBlockWithNThreadLocalReductions<2>( + uupmem, + uup, + sliceSize, + max_op, + max); + + // Store the mode in shared memory for use in finding the mode in the input + // slice + __shared__ T mode; + + // Given the above constraints, the mode is the value at the reduced index in + // the original sorted element buffer + if (tidx == 0) { + mode = smem[max.index]; + } + __syncthreads(); // broadcast mode + + // Finally, we need to find "an" index of the mode in the input + // Tensor. The API does not constrain which index we pick, but here + // we always pick the largest index. We store the index if the value + // is the mode, or 0 otherwise. Then find the maximum value. + // + // Again we reduce 2 elements in the thread's registers prior to the + // block-wide reduction + unsigned mode_index[2] = {0u, 0u}; + if (tidx * 2 < sliceSize) { + const unsigned idx = tidx * 2; + mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + if (tidx * 2 + 1 < sliceSize) { + const unsigned idx = tidx * 2 + 1; + mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u; + } + + struct MaxIndexOp { + inline __device__ unsigned combine(unsigned a, unsigned b) const { + return b > a ? b : a; + } + + inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } + } max_index_op; + + int64_t index = reduceBlockWithNThreadLocalReductions<2>( + reinterpret_cast(&shmem[0]), + mode_index, + sliceSize, + max_index_op, + 0u); + + // Finally, we have the mode, and an index where it occurs. We use a single + // thread to place this in the appropriate output position + if (tidx == 0) { + unsigned int outputOffset = + at::cuda::detail::IndexToOffset::get( + blockId, values); + values.data[outputOffset] = mode; + indices.data[outputOffset] = index; + } +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7efd68a5aa0316bc02e639f5382d3ed9fe2b4ef3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +void launch_fused_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t slice_size, int64_t slices); + +void launch_apply_mode_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t ndim); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1fffc057d29b50f533870dfd5b276677ee523b9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh @@ -0,0 +1,17 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +namespace at::native::internal { + +template +std::tuple unique_cuda_template( + const Tensor& self, + const bool consecutive, + const bool return_inverse, + const bool return_counts); + +} // namespace at::native::internal + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bd2c021b4c96d32bec0321bf5c9cfe52d7f45cbc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/cutlass_common.cuh @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::cuda::detail { + +template +struct enable_2x_kernel_for_sm89 : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm9x : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm10 : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm10_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +} // namespace at::cuda::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adagrad_impl.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adagrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a02bc34d0584acfc7371892ac7685de51da99286 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adagrad_impl.cuh @@ -0,0 +1,37 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native { + +void _fused_adagrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList state_sums, + at::TensorList state_steps, + const double lr, + const double lr_decay, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adagrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList state_sums, + at::TensorList state_steps, + const at::Tensor& lr, + const double lr_decay, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e7655b881383ca990b0aa3e84aac449cf3596bc7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native { + +void _fused_adam_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adam_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3b4a105428d4a8f11472ab0e9fbcd319744933d0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::native { + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +void _fused_adamw_amsgrad_cuda_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const at::Tensor& lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const std::optional& grad_scale, + const std::optional& found_inf); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..78a9f7d0cc4c939376c69b53056b3be69c681c38 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/im2col.cuh @@ -0,0 +1,341 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#include + +namespace at::native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy +// (borrowed from Caffe: +// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) +// CUDA_NUM_THREADS = 1024 + +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void im2col_kernel( + const int64_t n, + const dt* data_im, + const int64_t height, + const int64_t width, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + int64_t w_out = index % width_col; + + int64_t idx = index / width_col; + + int64_t h_out = idx % height_col; + int64_t channel_in = idx / height_col; + int64_t channel_out = channel_in * kernel_height * kernel_width; + int64_t h_in = h_out * stride_height - pad_height; + int64_t w_in = w_out * stride_width - pad_width; + + dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out; + const dt* im = data_im + (channel_in * height + h_in) * width + w_in; + + for (int64_t i = 0; i < kernel_height; ++i) { + for (int64_t j = 0; j < kernel_width; ++j) { + int64_t h = h_in + i * dilation_height; + int64_t w = w_in + j * dilation_width; + *col = (h >= 0 && w >= 0 && h < height && w < width) + ? im[i * dilation_height * width + j * dilation_width] + : static_cast
(0); + col += height_col * width_col; + } + } + } +} + +template +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__forceinline__ __device__ void col2im_device( + const int64_t index, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + accT val = static_cast(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
(val); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + col2im_device( + index, + data_col, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + } +} + +template +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_kernel + <<>>( + num_kernels, + data_col, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_batched_kernel( + const int64_t n, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t height, + const int64_t width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im, + const int64_t im_batch_stride) { + using accT = at::acc_type; + const auto im_numel = n * nbatch; + + CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) { + const auto ibatch = index / n; + const auto slice_index = index % n; + + col2im_device( + slice_index, + data_col + ibatch * col_batch_stride, + height, + width, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im + ibatch * im_batch_stride); + } +} + +template +void col2im_batched( + cudaStream_t stream, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im, + const int64_t im_batch_stride) { + const int64_t num_kernels = channels * height * width; + const int64_t output_numel = nbatch * num_kernels; + if (output_numel == 0) { + return; // No work to do + } + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_batched_kernel<<>>( + num_kernels, + data_col, + col_batch_stride, + nbatch, + height, + width, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im, + im_batch_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c86a786493ad6dbf4fd45dc1f14dab4f94fcb161 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh @@ -0,0 +1,689 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +namespace at::cuda { +//windows doesn't like large string literals, so split in two +const std::string reduction_template_0 = R"ESCAPE( + #define C10_HOST_DEVICE __host__ __device__ + #define C10_DEVICE __device__ + #if defined(__clang__) && defined(__HIP__) + #ifndef __forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + // until ROCm support for kernel asserts is restored + #define assert(expr) (static_cast(0)) + #endif + + template + __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + #if defined(__clang__) && defined(__HIP__) + return __shfl_down(value, delta, width); + #else + return __shfl_down_sync(mask, value, delta, width); + #endif + } + + + #if ${complex} + template + __device__ __forceinline__ std::complex WARP_SHFL_DOWN(std::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + return std::complex( + #if defined(__clang__) && defined(__HIP__) + __shfl_down(value.real(), delta, width), + __shfl_down(value.imag(), delta, width)); + #else + __shfl_down_sync(mask, value.real(), delta, width), + __shfl_down_sync(mask, value.imag(), delta, width)); + #endif + } + #endif + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + + C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; + } + + + + + struct ReduceConfig { + //has to match host-side ReduceConfig in the eager code + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + return offset; + } + + + }; + + +//TODO this will need to be different for more generic reduction functions +namespace reducer { + + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + + inline __device__ ${functor} + + inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + + inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + // wrap a normal reduction that ignores the index + inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) { + return combine(acc, val); + } +} + + +struct ReduceJitOp { + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + using InputCalculator = OffsetCalculator<1>; + using OutputCalculator = OffsetCalculator<2>; + +// static constexpr bool can_accumulate_in_output = +// std::is_convertible_v +// && std::is_convertible_v; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + uint32_t output_idx = config.output_idx<${output_vec_size}>(); + uint32_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + + value = thread_reduce<${output_vec_size}>(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce<${output_vec_size}>(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce<${output_vec_size}>(value, shared_memory); + } + + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce<${output_vec_size}>(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output<${output_vec_size}>(out, value); + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + assert(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](uint32_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](uint32_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + uint32_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = aligned_vector; + + uint32_t idx = config.input_idx(); + const uint32_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + scalar_t values[input_vec_size]; + + load_t *values_vector = reinterpret_cast(&values[0]); + + while (idx * input_vec_size + input_vec_size - 1 < end) { + *values_vector = reinterpret_cast(data)[idx]; + #pragma unroll + for (uint32_t i = 0; i < input_vec_size; i++) { + value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + uint32_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = reducer::combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + uint32_t idx = config.input_idx(); + const uint32_t end = config.num_inputs; + const uint32_t stride = config.step_input; + const int vt0=${vt0}; + + using arg_vec_t = Array; + using load_t = aligned_vector; + const load_t* data = reinterpret_cast(data_); + + // Multiple accumulators to remove dependency between unrolled loops. + arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + values[i] = data[calc(idx + i * stride) / output_vec_size]; + } + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + values[i] = data[calc(idx) / output_vec_size]; + idx += stride; + } + idx = idx_; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + template + C10_DEVICE Array block_x_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + #if defined(USE_ROCM) || defined(FBCODE_CAFFE2) + for (int offset = 1; offset < dim_x; offset <<= 1) { + #else + for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { + #endif + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = reducer::warp_shfl_down(value[i], offset); + value[i] = reducer::combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE Array block_y_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + )ESCAPE"; + + const std::string reduction_template_1 = R"ESCAPE( + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE Array accumulate_in_output( + Array out, + Array value + ) const { + Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + ret[i] = reducer::combine(*(out[i]), value[i]); + } + return ret; + } + + + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value + ) const { + assert(!final_output); + return (out_scalar_t)value; + } + + template + C10_DEVICE void set_results(const T x, const uint32_t base_offset) const { + assert(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + +//TODO - multi-output reduction - we won't be able to use thrust::pair +//just explicitly specify typed output reads/writes +//Currently implemented for max of two outputs +// template +// C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { +// if (noutputs >= 1) { +// auto res0 = (T1*)((char*)dst[0] + base_offset); +// *res0 = x.first; +// } +// if (noutputs >= 2) { +// // base offset is computed assuming element size being sizeof(T1), so we need to make a +// // correction to obtain the correct base offset +// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); +// *res1 = x.second; +// } +// } + + template + C10_DEVICE void set_results_to_output(Array value, Array base_offset) const { + assert(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(reducer::project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE Array global_reduce(Array value, Array *acc, char* shared_memory) const { + using arg_vec_t = Array; + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + uint32_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + uint32_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + __threadfence(); //complete acquire pattern + value = ident; + if (config.should_block_x_reduce()) { + uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + uint32_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } else { + uint32_t input_offset = threadIdx.y; + uint32_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +extern "C" +__launch_bounds__(${max_threads_lb}, 4) +__global__ void reduction_${name}_kernel(ReduceJitOp r){ + r.run(); +} +)ESCAPE"; + +const std::string reduction_template = reduction_template_0 + reduction_template_1; + + +const std::string &get_reduction_template() { + return reduction_template; +} + +} // namespace at::cuda + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_amp_update_scale_native.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_amp_update_scale_native.h new file mode 100644 index 0000000000000000000000000000000000000000..1f4121095527699fc19e2ad629c3e0d1a4e4bcfd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_amp_update_scale_native.h @@ -0,0 +1,29 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple _amp_update_scale(const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); +TORCH_API at::Tensor & _amp_update_scale_out(const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out); +TORCH_API at::Tensor & _amp_update_scale_cpu_(at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); +TORCH_API at::Tensor & _amp_update_scale_cuda_(at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..dd78590e835d0548ad9df3030e827c5a6428dfdb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h @@ -0,0 +1,28 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor _embedding_bag_per_sample_weights_backward(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1); + +} // namespace cuda +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_lazy_clone_ops.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_lazy_clone_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..6f53940c472c16e43eeccaa8f5b93957a1d77d3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_lazy_clone_ops.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API _lazy_clone { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + static constexpr const char* name = "aten::_lazy_clone"; + static constexpr const char* overload_name = ""; + static constexpr const char* schema_str = "_lazy_clone(Tensor self) -> Tensor"; + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +}} // namespace at::_ops + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_unsafe_masked_index_native.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_unsafe_masked_index_native.h new file mode 100644 index 0000000000000000000000000000000000000000..7c7d8b5970ca591e5ea630599c939ca62274e151 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_unsafe_masked_index_native.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor _unsafe_masked_index(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill); +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/le_meta_dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/le_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..abf361c208c9a56c339c3bb8f6a3b6f1bd23e3eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/le_meta_dispatch.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor le(const at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor & le_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor & le_outf(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); +TORCH_API at::Tensor & le_(at::Tensor & self, const at::Scalar & other); +TORCH_API at::Tensor le(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & le_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & le_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & le_(at::Tensor & self, const at::Tensor & other); + +} // namespace meta +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_householder_product.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_householder_product.h new file mode 100644 index 0000000000000000000000000000000000000000..2ffe202c65fe4b2bc0b5ddbcf610fda9d6f5d76f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_householder_product.h @@ -0,0 +1,45 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor +inline at::Tensor linalg_householder_product(const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product::call(input, tau); +} + +// aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & linalg_householder_product_out(at::Tensor & out, const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product_out::call(input, tau, out); +} +// aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & linalg_householder_product_outf(const at::Tensor & input, const at::Tensor & tau, at::Tensor & out) { + return at::_ops::linalg_householder_product_out::call(input, tau, out); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_qr.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_qr.h new file mode 100644 index 0000000000000000000000000000000000000000..60721caa16c157f4f6a02ca6ae043cb3f63015ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/linalg_qr.h @@ -0,0 +1,45 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) +inline ::std::tuple linalg_qr(const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr::call(A, mode); +} + +// aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) +inline ::std::tuple linalg_qr_out(at::Tensor & Q, at::Tensor & R, const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr_out::call(A, mode, Q, R); +} +// aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) +inline ::std::tuple linalg_qr_outf(const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R) { + return at::_ops::linalg_qr_out::call(A, mode, Q, R); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/max_pool1d_with_indices.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/max_pool1d_with_indices.h new file mode 100644 index 0000000000000000000000000000000000000000..88dc799be41b347b42a4bc8ba49d610649af74d0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/max_pool1d_with_indices.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) +inline ::std::tuple max_pool1d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool1d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/nextafter_meta_dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/nextafter_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..91bbdf872ab3ff8d0b084fecff11e2ae7b5bc56c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/nextafter_meta_dispatch.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor nextafter(const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & nextafter_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other); +TORCH_API at::Tensor & nextafter_outf(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); +TORCH_API at::Tensor & nextafter_(at::Tensor & self, const at::Tensor & other); + +} // namespace meta +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/reciprocal_cuda_dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/reciprocal_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..52af20b1c6299f7f7326b9c8de5325a2fbafe183 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/reciprocal_cuda_dispatch.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor reciprocal(const at::Tensor & self); +TORCH_API at::Tensor & reciprocal_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & reciprocal_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & reciprocal_(at::Tensor & self); + +} // namespace cuda +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_backward_cpu_dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5d55aadc9cd745fe44c7365c7f225e8acdc06d9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/softplus_backward_cpu_dispatch.h @@ -0,0 +1,30 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor softplus_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); +TORCH_API at::Tensor & softplus_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); +TORCH_API at::Tensor & softplus_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_modified_bessel_i0.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_modified_bessel_i0.h new file mode 100644 index 0000000000000000000000000000000000000000..e0f2c6d9372b10b5b405f593989ef140c40c6219 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/special_modified_bessel_i0.h @@ -0,0 +1,45 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::special_modified_bessel_i0(Tensor self) -> Tensor +inline at::Tensor special_modified_bessel_i0(const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0::call(self); +} + +// aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & special_modified_bessel_i0_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0_out::call(self, out); +} +// aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & special_modified_bessel_i0_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_i0_out::call(self, out); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/delimited_message_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/delimited_message_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a9838c1ee0b51dd9aef4faedb03e0cb17fa6f390 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/delimited_message_util.h @@ -0,0 +1,113 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Adapted from the patch of kenton@google.com (Kenton Varda) +// See https://github.com/protocolbuffers/protobuf/pull/710 for details. + +#ifndef GOOGLE_PROTOBUF_UTIL_DELIMITED_MESSAGE_UTIL_H__ +#define GOOGLE_PROTOBUF_UTIL_DELIMITED_MESSAGE_UTIL_H__ + + +#include + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace util { + +// Write a single size-delimited message from the given stream. Delimited +// format allows a single file or stream to contain multiple messages, +// whereas normally writing multiple non-delimited messages to the same +// stream would cause them to be merged. A delimited message is a varint +// encoding the message size followed by a message of exactly that size. +// +// Note that if you want to *read* a delimited message from a file descriptor +// or istream, you will need to construct an io::FileInputStream or +// io::OstreamInputStream (implementations of io::ZeroCopyStream) and use the +// utility function ParseDelimitedFromZeroCopyStream(). You must then +// continue to use the same ZeroCopyInputStream to read all further data from +// the stream until EOF. This is because these ZeroCopyInputStream +// implementations are buffered: they read a big chunk of data at a time, +// then parse it. As a result, they may read past the end of the delimited +// message. There is no way for them to push the extra data back into the +// underlying source, so instead you must keep using the same stream object. +bool PROTOBUF_EXPORT SerializeDelimitedToFileDescriptor( + const MessageLite& message, int file_descriptor); + +bool PROTOBUF_EXPORT SerializeDelimitedToOstream(const MessageLite& message, + std::ostream* output); + +// Read a single size-delimited message from the given stream. Delimited +// format allows a single file or stream to contain multiple messages, +// whereas normally parsing consumes the entire input. A delimited message +// is a varint encoding the message size followed by a message of exactly +// that size. +// +// If |clean_eof| is not NULL, then it will be set to indicate whether the +// stream ended cleanly. That is, if the stream ends without this method +// having read any data at all from it, then *clean_eof will be set true, +// otherwise it will be set false. Note that these methods return false +// on EOF, but they also return false on other errors, so |clean_eof| is +// needed to distinguish a clean end from errors. +bool PROTOBUF_EXPORT ParseDelimitedFromZeroCopyStream( + MessageLite* message, io::ZeroCopyInputStream* input, bool* clean_eof); + +bool PROTOBUF_EXPORT ParseDelimitedFromCodedStream(MessageLite* message, + io::CodedInputStream* input, + bool* clean_eof); + +// Write a single size-delimited message from the given stream. Delimited +// format allows a single file or stream to contain multiple messages, +// whereas normally writing multiple non-delimited messages to the same +// stream would cause them to be merged. A delimited message is a varint +// encoding the message size followed by a message of exactly that size. +bool PROTOBUF_EXPORT SerializeDelimitedToZeroCopyStream( + const MessageLite& message, io::ZeroCopyOutputStream* output); + +bool PROTOBUF_EXPORT SerializeDelimitedToCodedStream( + const MessageLite& message, io::CodedOutputStream* output); + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_DELIMITED_MESSAGE_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_comparator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_comparator.h new file mode 100644 index 0000000000000000000000000000000000000000..c86fe78795a809d5e48377e10ff35b2eb3431709 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/field_comparator.h @@ -0,0 +1,265 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Defines classes for field comparison. + +#ifndef GOOGLE_PROTOBUF_UTIL_FIELD_COMPARATOR_H__ +#define GOOGLE_PROTOBUF_UTIL_FIELD_COMPARATOR_H__ + +#include +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { + +class Message; +class EnumValueDescriptor; +class FieldDescriptor; + +namespace util { + +class FieldContext; +class MessageDifferencer; + +// Base class specifying the interface for comparing protocol buffer fields. +// Regular users should consider using or subclassing DefaultFieldComparator +// rather than this interface. +// Currently, this does not support comparing unknown fields. +class PROTOBUF_EXPORT FieldComparator { + public: + FieldComparator(); + virtual ~FieldComparator(); + + enum ComparisonResult { + SAME, // Compared fields are equal. In case of comparing submessages, + // user should not recursively compare their contents. + DIFFERENT, // Compared fields are different. In case of comparing + // submessages, user should not recursively compare their + // contents. + RECURSE, // Compared submessages need to be compared recursively. + // FieldComparator does not specify the semantics of recursive + // comparison. This value should not be returned for simple + // values. + }; + + // Compares the values of a field in two protocol buffer messages. + // Returns SAME or DIFFERENT for simple values, and SAME, DIFFERENT or RECURSE + // for submessages. Returning RECURSE for fields not being submessages is + // illegal. + // In case the given FieldDescriptor points to a repeated field, the indices + // need to be valid. Otherwise they should be ignored. + // + // FieldContext contains information about the specific instances of the + // fields being compared, versus FieldDescriptor which only contains general + // type information about the fields. + virtual ComparisonResult Compare(const Message& message_1, + const Message& message_2, + const FieldDescriptor* field, int index_1, + int index_2, + const util::FieldContext* field_context) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldComparator); +}; + +// Basic implementation of FieldComparator. Supports three modes of floating +// point value comparison: exact, approximate using MathUtil::AlmostEqual +// method, and arbitrarily precise using MathUtil::WithinFractionOrMargin. +class PROTOBUF_EXPORT DefaultFieldComparator : public FieldComparator { + public: + enum FloatComparison { + EXACT, // Floats and doubles are compared exactly. + APPROXIMATE, // Floats and doubles are compared using the + // MathUtil::AlmostEqual method or + // MathUtil::WithinFractionOrMargin method. + // TODO(ksroka): Introduce third value to differentiate uses of AlmostEqual + // and WithinFractionOrMargin. + }; + + // Creates new comparator with float comparison set to EXACT. + DefaultFieldComparator(); + + ~DefaultFieldComparator() override; + + ComparisonResult Compare(const Message& message_1, const Message& message_2, + const FieldDescriptor* field, int index_1, + int index_2, + const util::FieldContext* field_context) override; + + void set_float_comparison(FloatComparison float_comparison) { + float_comparison_ = float_comparison; + } + + FloatComparison float_comparison() const { return float_comparison_; } + + // Set whether the FieldComparator shall treat floats or doubles that are both + // NaN as equal (treat_nan_as_equal = true) or as different + // (treat_nan_as_equal = false). Default is treating NaNs always as different. + void set_treat_nan_as_equal(bool treat_nan_as_equal) { + treat_nan_as_equal_ = treat_nan_as_equal; + } + + bool treat_nan_as_equal() const { return treat_nan_as_equal_; } + + // Sets the fraction and margin for the float comparison of a given field. + // Uses MathUtil::WithinFractionOrMargin to compare the values. + // + // REQUIRES: field->cpp_type == FieldDescriptor::CPPTYPE_DOUBLE or + // field->cpp_type == FieldDescriptor::CPPTYPE_FLOAT + // REQUIRES: float_comparison_ == APPROXIMATE + void SetFractionAndMargin(const FieldDescriptor* field, double fraction, + double margin); + + // Sets the fraction and margin for the float comparison of all float and + // double fields, unless a field has been given a specific setting via + // SetFractionAndMargin() above. + // Uses MathUtil::WithinFractionOrMargin to compare the values. + // + // REQUIRES: float_comparison_ == APPROXIMATE + void SetDefaultFractionAndMargin(double fraction, double margin); + + protected: + // Compare using the provided message_differencer. For example, a subclass can + // use this method to compare some field in a certain way using the same + // message_differencer instance and the field context. + bool Compare(MessageDifferencer* differencer, const Message& message1, + const Message& message2, + const util::FieldContext* field_context); + + private: + // Defines the tolerance for floating point comparison (fraction and margin). + struct Tolerance { + double fraction; + double margin; + Tolerance() : fraction(0.0), margin(0.0) {} + Tolerance(double f, double m) : fraction(f), margin(m) {} + }; + + // Defines the map to store the tolerances for floating point comparison. + typedef std::map ToleranceMap; + + // The following methods get executed when CompareFields is called for the + // basic types (instead of submessages). They return true on success. One + // can use ResultFromBoolean() to convert that boolean to a ComparisonResult + // value. + bool CompareBool(const FieldDescriptor& /* unused */, bool value_1, + bool value_2) { + return value_1 == value_2; + } + + // Uses CompareDoubleOrFloat, a helper function used by both CompareDouble and + // CompareFloat. + bool CompareDouble(const FieldDescriptor& field, double value_1, + double value_2); + + bool CompareEnum(const FieldDescriptor& field, + const EnumValueDescriptor* value_1, + const EnumValueDescriptor* value_2); + + // Uses CompareDoubleOrFloat, a helper function used by both CompareDouble and + // CompareFloat. + bool CompareFloat(const FieldDescriptor& field, float value_1, float value_2); + + bool CompareInt32(const FieldDescriptor& /* unused */, int32 value_1, + int32 value_2) { + return value_1 == value_2; + } + + bool CompareInt64(const FieldDescriptor& /* unused */, int64 value_1, + int64 value_2) { + return value_1 == value_2; + } + + bool CompareString(const FieldDescriptor& /* unused */, + const std::string& value_1, const std::string& value_2) { + return value_1 == value_2; + } + + bool CompareUInt32(const FieldDescriptor& /* unused */, uint32 value_1, + uint32 value_2) { + return value_1 == value_2; + } + + bool CompareUInt64(const FieldDescriptor& /* unused */, uint64 value_1, + uint64 value_2) { + return value_1 == value_2; + } + + // This function is used by CompareDouble and CompareFloat to avoid code + // duplication. There are no checks done against types of the values passed, + // but it's likely to fail if passed non-numeric arguments. + template + bool CompareDoubleOrFloat(const FieldDescriptor& field, T value_1, T value_2); + + // Returns FieldComparator::SAME if boolean_result is true and + // FieldComparator::DIFFERENT otherwise. + ComparisonResult ResultFromBoolean(bool boolean_result) const; + + FloatComparison float_comparison_; + + // If true, floats and doubles that are both NaN are considered to be + // equal. Otherwise, two floats or doubles that are NaN are considered to be + // different. + bool treat_nan_as_equal_; + + // True iff default_tolerance_ has been explicitly set. + // + // If false, then the default tolerance for floats and doubles is that which + // is used by MathUtil::AlmostEquals(). + bool has_default_tolerance_; + + // Default float/double tolerance. Only meaningful if + // has_default_tolerance_ == true. + Tolerance default_tolerance_; + + // Field-specific float/double tolerances, which override any default for + // those particular fields. + ToleranceMap map_tolerance_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DefaultFieldComparator); +}; + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_FIELD_COMPARATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/json_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/json_util.h new file mode 100644 index 0000000000000000000000000000000000000000..7a8026248d63929a6f1d43fb312af94faa8de155 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/json_util.h @@ -0,0 +1,208 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Utility functions to convert between protobuf binary format and proto3 JSON +// format. +#ifndef GOOGLE_PROTOBUF_UTIL_JSON_UTIL_H__ +#define GOOGLE_PROTOBUF_UTIL_JSON_UTIL_H__ + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace io { +class ZeroCopyInputStream; +class ZeroCopyOutputStream; +} // namespace io +namespace util { + +struct JsonParseOptions { + // Whether to ignore unknown JSON fields during parsing + bool ignore_unknown_fields; + + // If true, when a lowercase enum value fails to parse, try convert it to + // UPPER_CASE and see if it matches a valid enum. + // WARNING: This option exists only to preserve legacy behavior. Avoid using + // this option. If your enum needs to support different casing, consider using + // allow_alias instead. + bool case_insensitive_enum_parsing; + + JsonParseOptions() + : ignore_unknown_fields(false), + case_insensitive_enum_parsing(false) {} +}; + +struct JsonPrintOptions { + // Whether to add spaces, line breaks and indentation to make the JSON output + // easy to read. + bool add_whitespace; + // Whether to always print primitive fields. By default proto3 primitive + // fields with default values will be omitted in JSON output. For example, an + // int32 field set to 0 will be omitted. Set this flag to true will override + // the default behavior and print primitive fields regardless of their values. + bool always_print_primitive_fields; + // Whether to always print enums as ints. By default they are rendered as + // strings. + bool always_print_enums_as_ints; + // Whether to preserve proto field names + bool preserve_proto_field_names; + + JsonPrintOptions() + : add_whitespace(false), + always_print_primitive_fields(false), + always_print_enums_as_ints(false), + preserve_proto_field_names(false) {} +}; + +// DEPRECATED. Use JsonPrintOptions instead. +typedef JsonPrintOptions JsonOptions; + +// Converts from protobuf message to JSON and appends it to |output|. This is a +// simple wrapper of BinaryToJsonString(). It will use the DescriptorPool of the +// passed-in message to resolve Any types. +PROTOBUF_EXPORT util::Status MessageToJsonString(const Message& message, + std::string* output, + const JsonOptions& options); + +inline util::Status MessageToJsonString(const Message& message, + std::string* output) { + return MessageToJsonString(message, output, JsonOptions()); +} + +// Converts from JSON to protobuf message. This is a simple wrapper of +// JsonStringToBinary(). It will use the DescriptorPool of the passed-in +// message to resolve Any types. +PROTOBUF_EXPORT util::Status JsonStringToMessage( + StringPiece input, Message* message, const JsonParseOptions& options); + +inline util::Status JsonStringToMessage(StringPiece input, + Message* message) { + return JsonStringToMessage(input, message, JsonParseOptions()); +} + +// Converts protobuf binary data to JSON. +// The conversion will fail if: +// 1. TypeResolver fails to resolve a type. +// 2. input is not valid protobuf wire format, or conflicts with the type +// information returned by TypeResolver. +// Note that unknown fields will be discarded silently. +PROTOBUF_EXPORT util::Status BinaryToJsonStream( + TypeResolver* resolver, const std::string& type_url, + io::ZeroCopyInputStream* binary_input, + io::ZeroCopyOutputStream* json_output, const JsonPrintOptions& options); + +inline util::Status BinaryToJsonStream( + TypeResolver* resolver, const std::string& type_url, + io::ZeroCopyInputStream* binary_input, + io::ZeroCopyOutputStream* json_output) { + return BinaryToJsonStream(resolver, type_url, binary_input, json_output, + JsonPrintOptions()); +} + +PROTOBUF_EXPORT util::Status BinaryToJsonString( + TypeResolver* resolver, const std::string& type_url, + const std::string& binary_input, std::string* json_output, + const JsonPrintOptions& options); + +inline util::Status BinaryToJsonString(TypeResolver* resolver, + const std::string& type_url, + const std::string& binary_input, + std::string* json_output) { + return BinaryToJsonString(resolver, type_url, binary_input, json_output, + JsonPrintOptions()); +} + +// Converts JSON data to protobuf binary format. +// The conversion will fail if: +// 1. TypeResolver fails to resolve a type. +// 2. input is not valid JSON format, or conflicts with the type +// information returned by TypeResolver. +PROTOBUF_EXPORT util::Status JsonToBinaryStream( + TypeResolver* resolver, const std::string& type_url, + io::ZeroCopyInputStream* json_input, + io::ZeroCopyOutputStream* binary_output, const JsonParseOptions& options); + +inline util::Status JsonToBinaryStream( + TypeResolver* resolver, const std::string& type_url, + io::ZeroCopyInputStream* json_input, + io::ZeroCopyOutputStream* binary_output) { + return JsonToBinaryStream(resolver, type_url, json_input, binary_output, + JsonParseOptions()); +} + +PROTOBUF_EXPORT util::Status JsonToBinaryString( + TypeResolver* resolver, const std::string& type_url, + StringPiece json_input, std::string* binary_output, + const JsonParseOptions& options); + +inline util::Status JsonToBinaryString(TypeResolver* resolver, + const std::string& type_url, + StringPiece json_input, + std::string* binary_output) { + return JsonToBinaryString(resolver, type_url, json_input, binary_output, + JsonParseOptions()); +} + +namespace internal { +// Internal helper class. Put in the header so we can write unit-tests for it. +class PROTOBUF_EXPORT ZeroCopyStreamByteSink : public strings::ByteSink { + public: + explicit ZeroCopyStreamByteSink(io::ZeroCopyOutputStream* stream) + : stream_(stream), buffer_(NULL), buffer_size_(0) {} + ~ZeroCopyStreamByteSink(); + + void Append(const char* bytes, size_t len) override; + + private: + io::ZeroCopyOutputStream* stream_; + void* buffer_; + int buffer_size_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ZeroCopyStreamByteSink); +}; +} // namespace internal + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_JSON_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/message_differencer.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/message_differencer.h new file mode 100644 index 0000000000000000000000000000000000000000..afef8c8b276991f39e1817bca505d852aca7f58b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/message_differencer.h @@ -0,0 +1,941 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jschorr@google.com (Joseph Schorr) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file defines static methods and classes for comparing Protocol +// Messages. +// +// Aug. 2008: Added Unknown Fields Comparison for messages. +// Aug. 2009: Added different options to compare repeated fields. +// Apr. 2010: Moved field comparison to FieldComparator. + +#ifndef GOOGLE_PROTOBUF_UTIL_MESSAGE_DIFFERENCER_H__ +#define GOOGLE_PROTOBUF_UTIL_MESSAGE_DIFFERENCER_H__ + +#include +#include +#include +#include +#include + +#include // FieldDescriptor +#include // Message +#include +#include + +// Always include as last one, otherwise it can break compilation +#include + +namespace google { +namespace protobuf { + +class DynamicMessageFactory; +class FieldDescriptor; + +namespace io { +class ZeroCopyOutputStream; +class Printer; +} // namespace io + +namespace util { + +class DefaultFieldComparator; +class FieldContext; // declared below MessageDifferencer + +// Defines a collection of field descriptors. +// In case of internal google codebase we are using absl::FixedArray instead +// of vector. It significantly speeds up proto comparison (by ~30%) by +// reducing the number of malloc/free operations +typedef std::vector FieldDescriptorArray; + +// A basic differencer that can be used to determine +// the differences between two specified Protocol Messages. If any differences +// are found, the Compare method will return false, and any differencer reporter +// specified via ReportDifferencesTo will have its reporting methods called (see +// below for implementation of the report). Based off of the original +// ProtocolDifferencer implementation in //net/proto/protocol-differencer.h +// (Thanks Todd!). +// +// MessageDifferencer REQUIRES that compared messages be the same type, defined +// as messages that share the same descriptor. If not, the behavior of this +// class is undefined. +// +// People disagree on what MessageDifferencer should do when asked to compare +// messages with different descriptors. Some people think it should always +// return false. Others expect it to try to look for similar fields and +// compare them anyway -- especially if the descriptors happen to be identical. +// If we chose either of these behaviors, some set of people would find it +// surprising, and could end up writing code expecting the other behavior +// without realizing their error. Therefore, we forbid that usage. +// +// This class is implemented based on the proto2 reflection. The performance +// should be good enough for normal usages. However, for places where the +// performance is extremely sensitive, there are several alternatives: +// - Comparing serialized string +// Downside: false negatives (there are messages that are the same but their +// serialized strings are different). +// - Equals code generator by compiler plugin (net/proto2/contrib/equals_plugin) +// Downside: more generated code; maintenance overhead for the additional rule +// (must be in sync with the original proto_library). +// +// Note on handling of google.protobuf.Any: MessageDifferencer automatically +// unpacks Any::value into a Message and compares its individual fields. +// Messages encoded in a repeated Any cannot be compared using TreatAsMap. +// +// Note on thread-safety: MessageDifferencer is *not* thread-safe. You need to +// guard it with a lock to use the same MessageDifferencer instance from +// multiple threads. Note that it's fine to call static comparison methods +// (like MessageDifferencer::Equals) concurrently, but it's not recommended for +// performance critical code as it leads to extra allocations. +class PROTOBUF_EXPORT MessageDifferencer { + public: + // Determines whether the supplied messages are equal. Equality is defined as + // all fields within the two messages being set to the same value. Primitive + // fields and strings are compared by value while embedded messages/groups + // are compared as if via a recursive call. Use Compare() with IgnoreField() + // if some fields should be ignored in the comparison. Use Compare() with + // TreatAsSet() if there are repeated fields where ordering does not matter. + // + // This method REQUIRES that the two messages have the same + // Descriptor (message1.GetDescriptor() == message2.GetDescriptor()). + static bool Equals(const Message& message1, const Message& message2); + + // Determines whether the supplied messages are equivalent. Equivalency is + // defined as all fields within the two messages having the same value. This + // differs from the Equals method above in that fields with default values + // are considered set to said value automatically. For details on how default + // values are defined for each field type, see: + // https://developers.google.com/protocol-buffers/docs/proto?csw=1#optional. + // Also, Equivalent() ignores unknown fields. Use IgnoreField() and Compare() + // if some fields should be ignored in the comparison. + // + // This method REQUIRES that the two messages have the same + // Descriptor (message1.GetDescriptor() == message2.GetDescriptor()). + static bool Equivalent(const Message& message1, const Message& message2); + + // Determines whether the supplied messages are approximately equal. + // Approximate equality is defined as all fields within the two messages + // being approximately equal. Primitive (non-float) fields and strings are + // compared by value, floats are compared using MathUtil::AlmostEquals() and + // embedded messages/groups are compared as if via a recursive call. Use + // IgnoreField() and Compare() if some fields should be ignored in the + // comparison. + // + // This method REQUIRES that the two messages have the same + // Descriptor (message1.GetDescriptor() == message2.GetDescriptor()). + static bool ApproximatelyEquals(const Message& message1, + const Message& message2); + + // Determines whether the supplied messages are approximately equivalent. + // Approximate equivalency is defined as all fields within the two messages + // being approximately equivalent. As in + // MessageDifferencer::ApproximatelyEquals, primitive (non-float) fields and + // strings are compared by value, floats are compared using + // MathUtil::AlmostEquals() and embedded messages/groups are compared as if + // via a recursive call. However, fields with default values are considered + // set to said value, as per MessageDiffencer::Equivalent. Use IgnoreField() + // and Compare() if some fields should be ignored in the comparison. + // + // This method REQUIRES that the two messages have the same + // Descriptor (message1.GetDescriptor() == message2.GetDescriptor()). + static bool ApproximatelyEquivalent(const Message& message1, + const Message& message2); + + // Identifies an individual field in a message instance. Used for field_path, + // below. + struct SpecificField { + // For known fields, "field" is filled in and "unknown_field_number" is -1. + // For unknown fields, "field" is NULL, "unknown_field_number" is the field + // number, and "unknown_field_type" is its type. + const FieldDescriptor* field; + int unknown_field_number; + UnknownField::Type unknown_field_type; + + // If this a repeated field, "index" is the index within it. For unknown + // fields, this is the index of the field among all unknown fields of the + // same field number and type. + int index; + + // If "field" is a repeated field which is being treated as a map or + // a set (see TreatAsMap() and TreatAsSet(), below), new_index indicates + // the index the position to which the element has moved. If the element + // has not moved, "new_index" will have the same value as "index". + int new_index; + + // For unknown fields, these are the pointers to the UnknownFieldSet + // containing the unknown fields. In certain cases (e.g. proto1's + // MessageSet, or nested groups of unknown fields), these may differ from + // the messages' internal UnknownFieldSets. + const UnknownFieldSet* unknown_field_set1; + const UnknownFieldSet* unknown_field_set2; + + // For unknown fields, these are the index of the field within the + // UnknownFieldSets. One or the other will be -1 when + // reporting an addition or deletion. + int unknown_field_index1; + int unknown_field_index2; + + SpecificField() + : field(NULL), + unknown_field_number(-1), + index(-1), + new_index(-1), + unknown_field_set1(NULL), + unknown_field_set2(NULL), + unknown_field_index1(-1), + unknown_field_index2(-1) {} + }; + + // Abstract base class from which all MessageDifferencer + // reporters derive. The five Report* methods below will be called when + // a field has been added, deleted, modified, moved, or matched. The third + // argument is a vector of FieldDescriptor pointers which describes the chain + // of fields that was taken to find the current field. For example, for a + // field found in an embedded message, the vector will contain two + // FieldDescriptors. The first will be the field of the embedded message + // itself and the second will be the actual field in the embedded message + // that was added/deleted/modified. + // Fields will be reported in PostTraversalOrder. + // For example, given following proto, if both baz and quux are changed. + // foo { + // bar { + // baz: 1 + // quux: 2 + // } + // } + // ReportModified will be invoked with following order: + // 1. foo.bar.baz or foo.bar.quux + // 2. foo.bar.quux or foo.bar.baz + // 2. foo.bar + // 3. foo + class PROTOBUF_EXPORT Reporter { + public: + Reporter(); + virtual ~Reporter(); + + // Reports that a field has been added into Message2. + virtual void ReportAdded(const Message& message1, const Message& message2, + const std::vector& field_path) = 0; + + // Reports that a field has been deleted from Message1. + virtual void ReportDeleted( + const Message& message1, const Message& message2, + const std::vector& field_path) = 0; + + // Reports that the value of a field has been modified. + virtual void ReportModified( + const Message& message1, const Message& message2, + const std::vector& field_path) = 0; + + // Reports that a repeated field has been moved to another location. This + // only applies when using TreatAsSet or TreatAsMap() -- see below. Also + // note that for any given field, ReportModified and ReportMoved are + // mutually exclusive. If a field has been both moved and modified, then + // only ReportModified will be called. + virtual void ReportMoved( + const Message& /* message1 */, const Message& /* message2 */, + const std::vector& /* field_path */) {} + + // Reports that two fields match. Useful for doing side-by-side diffs. + // This function is mutually exclusive with ReportModified and ReportMoved. + // Note that you must call set_report_matches(true) before calling Compare + // to make use of this function. + virtual void ReportMatched( + const Message& /* message1 */, const Message& /* message2 */, + const std::vector& /* field_path */) {} + + // Reports that two fields would have been compared, but the + // comparison has been skipped because the field was marked as + // 'ignored' using IgnoreField(). This function is mutually + // exclusive with all the other Report() functions. + // + // The contract of ReportIgnored is slightly different than the + // other Report() functions, in that |field_path.back().index| is + // always equal to -1, even if the last field is repeated. This is + // because while the other Report() functions indicate where in a + // repeated field the action (Addition, Deletion, etc...) + // happened, when a repeated field is 'ignored', the differencer + // simply calls ReportIgnored on the repeated field as a whole and + // moves on without looking at its individual elements. + // + // Furthermore, ReportIgnored() does not indicate whether the + // fields were in fact equal or not, as Compare() does not inspect + // these fields at all. It is up to the Reporter to decide whether + // the fields are equal or not (perhaps with a second call to + // Compare()), if it cares. + virtual void ReportIgnored( + const Message& /* message1 */, const Message& /* message2 */, + const std::vector& /* field_path */) {} + + // Report that an unknown field is ignored. (see comment above). + // Note this is a different function since the last SpecificField in field + // path has a null field. This could break existing Reporter. + virtual void ReportUnknownFieldIgnored( + const Message& /* message1 */, const Message& /* message2 */, + const std::vector& /* field_path */) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Reporter); + }; + + // MapKeyComparator is used to determine if two elements have the same key + // when comparing elements of a repeated field as a map. + class PROTOBUF_EXPORT MapKeyComparator { + public: + MapKeyComparator(); + virtual ~MapKeyComparator(); + + virtual bool IsMatch( + const Message& /* message1 */, const Message& /* message2 */, + const std::vector& /* parent_fields */) const { + GOOGLE_CHECK(false) << "IsMatch() is not implemented."; + return false; + } + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MapKeyComparator); + }; + + // Abstract base class from which all IgnoreCriteria derive. + // By adding IgnoreCriteria more complex ignore logic can be implemented. + // IgnoreCriteria are registed with AddIgnoreCriteria. For each compared + // field IsIgnored is called on each added IgnoreCriteria until one returns + // true or all return false. + // IsIgnored is called for fields where at least one side has a value. + class PROTOBUF_EXPORT IgnoreCriteria { + public: + IgnoreCriteria(); + virtual ~IgnoreCriteria(); + + // Returns true if the field should be ignored. + virtual bool IsIgnored( + const Message& /* message1 */, const Message& /* message2 */, + const FieldDescriptor* /* field */, + const std::vector& /* parent_fields */) = 0; + + // Returns true if the unknown field should be ignored. + // Note: This will be called for unknown fields as well in which case + // field.field will be null. + virtual bool IsUnknownFieldIgnored( + const Message& /* message1 */, const Message& /* message2 */, + const SpecificField& /* field */, + const std::vector& /* parent_fields */) { + return false; + } + }; + + // To add a Reporter, construct default here, then use ReportDifferencesTo or + // ReportDifferencesToString. + explicit MessageDifferencer(); + + ~MessageDifferencer(); + + enum MessageFieldComparison { + EQUAL, // Fields must be present in both messages + // for the messages to be considered the same. + EQUIVALENT, // Fields with default values are considered set + // for comparison purposes even if not explicitly + // set in the messages themselves. Unknown fields + // are ignored. + }; + + enum Scope { + FULL, // All fields of both messages are considered in the comparison. + PARTIAL // Only fields present in the first message are considered; fields + // set only in the second message will be skipped during + // comparison. + }; + + // DEPRECATED. Use FieldComparator::FloatComparison instead. + enum FloatComparison { + EXACT, // Floats and doubles are compared exactly. + APPROXIMATE // Floats and doubles are compared using the + // MathUtil::AlmostEquals method. + }; + + enum RepeatedFieldComparison { + AS_LIST, // Repeated fields are compared in order. Differing values at + // the same index are reported using ReportModified(). If the + // repeated fields have different numbers of elements, the + // unpaired elements are reported using ReportAdded() or + // ReportDeleted(). + AS_SET, // Treat all the repeated fields as sets. + // See TreatAsSet(), as below. + AS_SMART_LIST, // Similar to AS_SET, but preserve the order and find the + // longest matching sequence from the first matching + // element. To use an optimal solution, call + // SetMatchIndicesForSmartListCallback() to pass it in. + AS_SMART_SET, // Similar to AS_SET, but match elements with fewest diffs. + }; + + // The elements of the given repeated field will be treated as a set for + // diffing purposes, so different orderings of the same elements will be + // considered equal. Elements which are present on both sides of the + // comparison but which have changed position will be reported with + // ReportMoved(). Elements which only exist on one side or the other are + // reported with ReportAdded() and ReportDeleted() regardless of their + // positions. ReportModified() is never used for this repeated field. If + // the only differences between the compared messages is that some fields + // have been moved, then the comparison returns true. + // + // Note that despite the name of this method, this is really + // comparison as multisets: if one side of the comparison has a duplicate + // in the repeated field but the other side doesn't, this will count as + // a mismatch. + // + // If the scope of comparison is set to PARTIAL, then in addition to what's + // above, extra values added to repeated fields of the second message will + // not cause the comparison to fail. + // + // Note that set comparison is currently O(k * n^2) (where n is the total + // number of elements, and k is the average size of each element). In theory + // it could be made O(n * k) with a more complex hashing implementation. Feel + // free to contribute one if the current implementation is too slow for you. + // If partial matching is also enabled, the time complexity will be O(k * n^2 + // + n^3) in which n^3 is the time complexity of the maximum matching + // algorithm. + // + // REQUIRES: field->is_repeated() and field not registered with TreatAsList + void TreatAsSet(const FieldDescriptor* field); + void TreatAsSmartSet(const FieldDescriptor* field); + + // The elements of the given repeated field will be treated as a list for + // diffing purposes, so different orderings of the same elements will NOT be + // considered equal. + // + // REQUIRED: field->is_repeated() and field not registered with TreatAsSet + void TreatAsList(const FieldDescriptor* field); + // Note that the complexity is similar to treating as SET. + void TreatAsSmartList(const FieldDescriptor* field); + + // The elements of the given repeated field will be treated as a map for + // diffing purposes, with |key| being the map key. Thus, elements with the + // same key will be compared even if they do not appear at the same index. + // Differences are reported similarly to TreatAsSet(), except that + // ReportModified() is used to report elements with the same key but + // different values. Note that if an element is both moved and modified, + // only ReportModified() will be called. As with TreatAsSet, if the only + // differences between the compared messages is that some fields have been + // moved, then the comparison returns true. See TreatAsSet for notes on + // performance. + // + // REQUIRES: field->is_repeated() + // REQUIRES: field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE + // REQUIRES: key->containing_type() == field->message_type() + void TreatAsMap(const FieldDescriptor* field, const FieldDescriptor* key); + // Same as TreatAsMap except that this method will use multiple fields as + // the key in comparison. All specified fields in 'key_fields' should be + // present in the compared elements. Two elements will be treated as having + // the same key iff they have the same value for every specified field. There + // are two steps in the comparison process. The first one is key matching. + // Every element from one message will be compared to every element from + // the other message. Only fields in 'key_fields' are compared in this step + // to decide if two elements have the same key. The second step is value + // comparison. Those pairs of elements with the same key (with equal value + // for every field in 'key_fields') will be compared in this step. + // Time complexity of the first step is O(s * m * n ^ 2) where s is the + // average size of the fields specified in 'key_fields', m is the number of + // fields in 'key_fields' and n is the number of elements. If partial + // matching is enabled, an extra O(n^3) will be incured by the maximum + // matching algorithm. The second step is O(k * n) where k is the average + // size of each element. + void TreatAsMapWithMultipleFieldsAsKey( + const FieldDescriptor* field, + const std::vector& key_fields); + // Same as TreatAsMapWithMultipleFieldsAsKey, except that each of the field + // do not necessarily need to be a direct subfield. Each element in + // key_field_paths indicate a path from the message being compared, listing + // successive subfield to reach the key field. + // + // REQUIRES: + // for key_field_path in key_field_paths: + // key_field_path[0]->containing_type() == field->message_type() + // for i in [0, key_field_path.size() - 1): + // key_field_path[i+1]->containing_type() == + // key_field_path[i]->message_type() + // key_field_path[i]->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE + // !key_field_path[i]->is_repeated() + void TreatAsMapWithMultipleFieldPathsAsKey( + const FieldDescriptor* field, + const std::vector >& key_field_paths); + + // Uses a custom MapKeyComparator to determine if two elements have the same + // key when comparing a repeated field as a map. + // The caller is responsible to delete the key_comparator. + // This method varies from TreatAsMapWithMultipleFieldsAsKey only in the + // first key matching step. Rather than comparing some specified fields, it + // will invoke the IsMatch method of the given 'key_comparator' to decide if + // two elements have the same key. + void TreatAsMapUsingKeyComparator(const FieldDescriptor* field, + const MapKeyComparator* key_comparator); + + // Initiates and returns a new instance of MultipleFieldsMapKeyComparator. + MapKeyComparator* CreateMultipleFieldsMapKeyComparator( + const std::vector >& key_field_paths); + + // Add a custom ignore criteria that is evaluated in addition to the + // ignored fields added with IgnoreField. + // Takes ownership of ignore_criteria. + void AddIgnoreCriteria(IgnoreCriteria* ignore_criteria); + + // Indicates that any field with the given descriptor should be + // ignored for the purposes of comparing two messages. This applies + // to fields nested in the message structure as well as top level + // ones. When the MessageDifferencer encounters an ignored field, + // ReportIgnored is called on the reporter, if one is specified. + // + // The only place where the field's 'ignored' status is not applied is when + // it is being used as a key in a field passed to TreatAsMap or is one of + // the fields passed to TreatAsMapWithMultipleFieldsAsKey. + // In this case it is compared in key matching but after that it's ignored + // in value comparison. + void IgnoreField(const FieldDescriptor* field); + + // Sets the field comparator used to determine differences between protocol + // buffer fields. By default it's set to a DefaultFieldComparator instance. + // MessageDifferencer doesn't take ownership over the passed object. + // Note that this method must be called before Compare for the comparator to + // be used. + void set_field_comparator(FieldComparator* comparator); + + // DEPRECATED. Pass a DefaultFieldComparator instance instead. + // Sets the fraction and margin for the float comparison of a given field. + // Uses MathUtil::WithinFractionOrMargin to compare the values. + // NOTE: this method does nothing if differencer's field comparator has been + // set to a custom object. + // + // REQUIRES: field->cpp_type == FieldDescriptor::CPPTYPE_DOUBLE or + // field->cpp_type == FieldDescriptor::CPPTYPE_FLOAT + // REQUIRES: float_comparison_ == APPROXIMATE + void SetFractionAndMargin(const FieldDescriptor* field, double fraction, + double margin); + + // Sets the type of comparison (as defined in the MessageFieldComparison + // enumeration above) that is used by this differencer when determining how + // to compare fields in messages. + void set_message_field_comparison(MessageFieldComparison comparison); + + // Tells the differencer whether or not to report matches. This method must + // be called before Compare. The default for a new differencer is false. + void set_report_matches(bool report_matches) { + report_matches_ = report_matches; + } + + // Tells the differencer whether or not to report moves (in a set or map + // repeated field). This method must be called before Compare. The default for + // a new differencer is true. + void set_report_moves(bool report_moves) { report_moves_ = report_moves; } + + // Tells the differencer whether or not to report ignored values. This method + // must be called before Compare. The default for a new differencer is true. + void set_report_ignores(bool report_ignores) { + report_ignores_ = report_ignores; + } + + // Sets the scope of the comparison (as defined in the Scope enumeration + // above) that is used by this differencer when determining which fields to + // compare between the messages. + void set_scope(Scope scope); + + // Returns the current scope used by this differencer. + Scope scope(); + + // DEPRECATED. Pass a DefaultFieldComparator instance instead. + // Sets the type of comparison (as defined in the FloatComparison enumeration + // above) that is used by this differencer when comparing float (and double) + // fields in messages. + // NOTE: this method does nothing if differencer's field comparator has been + // set to a custom object. + void set_float_comparison(FloatComparison comparison); + + // Sets the type of comparison for repeated field (as defined in the + // RepeatedFieldComparison enumeration above) that is used by this + // differencer when compare repeated fields in messages. + void set_repeated_field_comparison(RepeatedFieldComparison comparison); + + // Returns the current repeated field comparison used by this differencer. + RepeatedFieldComparison repeated_field_comparison(); + + // Compares the two specified messages, returning true if they are the same, + // false otherwise. If this method returns false, any changes between the + // two messages will be reported if a Reporter was specified via + // ReportDifferencesTo (see also ReportDifferencesToString). + // + // This method REQUIRES that the two messages have the same + // Descriptor (message1.GetDescriptor() == message2.GetDescriptor()). + bool Compare(const Message& message1, const Message& message2); + + // Same as above, except comparing only the list of fields specified by the + // two vectors of FieldDescriptors. + bool CompareWithFields( + const Message& message1, const Message& message2, + const std::vector& message1_fields, + const std::vector& message2_fields); + + // Automatically creates a reporter that will output the differences + // found (if any) to the specified output string pointer. Note that this + // method must be called before Compare. + void ReportDifferencesToString(std::string* output); + + // Tells the MessageDifferencer to report differences via the specified + // reporter. Note that this method must be called before Compare for + // the reporter to be used. It is the responsibility of the caller to delete + // this object. + // If the provided pointer equals NULL, the MessageDifferencer stops reporting + // differences to any previously set reporters or output strings. + void ReportDifferencesTo(Reporter* reporter); + + // An implementation of the MessageDifferencer Reporter that outputs + // any differences found in human-readable form to the supplied + // ZeroCopyOutputStream or Printer. If a printer is used, the delimiter + // *must* be '$'. + // + // WARNING: this reporter does not necessarily flush its output until it is + // destroyed. As a result, it is not safe to assume the output is valid or + // complete until after you destroy the reporter. For example, if you use a + // StreamReporter to write to a StringOutputStream, the target string may + // contain uninitialized data until the reporter is destroyed. + class PROTOBUF_EXPORT StreamReporter : public Reporter { + public: + explicit StreamReporter(io::ZeroCopyOutputStream* output); + explicit StreamReporter(io::Printer* printer); // delimiter '$' + ~StreamReporter() override; + + // When set to true, the stream reporter will also output aggregates nodes + // (i.e. messages and groups) whose subfields have been modified. When + // false, will only report the individual subfields. Defaults to false. + void set_report_modified_aggregates(bool report) { + report_modified_aggregates_ = report; + } + + // The following are implementations of the methods described above. + + void ReportAdded(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportDeleted(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportModified(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportMoved(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportMatched(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportIgnored(const Message& message1, const Message& message2, + const std::vector& field_path) override; + + void ReportUnknownFieldIgnored( + const Message& message1, const Message& message2, + const std::vector& field_path) override; + + protected: + // Prints the specified path of fields to the buffer. message is used to + // print map keys. + virtual void PrintPath(const std::vector& field_path, + bool left_side, const Message& message); + + // Prints the specified path of fields to the buffer. + virtual void PrintPath(const std::vector& field_path, + bool left_side); + + // Prints the value of fields to the buffer. left_side is true if the + // given message is from the left side of the comparison, false if it + // was the right. This is relevant only to decide whether to follow + // unknown_field_index1 or unknown_field_index2 when an unknown field + // is encountered in field_path. + virtual void PrintValue(const Message& message, + const std::vector& field_path, + bool left_side); + + // Prints the specified path of unknown fields to the buffer. + virtual void PrintUnknownFieldValue(const UnknownField* unknown_field); + + // Just print a string + void Print(const std::string& str); + + private: + io::Printer* printer_; + bool delete_printer_; + bool report_modified_aggregates_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(StreamReporter); + }; + + private: + friend class DefaultFieldComparator; + + // A MapKeyComparator to be used in TreatAsMapUsingKeyComparator. + // Implementation of this class needs to do field value comparison which + // relies on some private methods of MessageDifferencer. That's why this + // class is declared as a nested class of MessageDifferencer. + class MultipleFieldsMapKeyComparator; + + // A MapKeyComparator for use with map_entries. + class PROTOBUF_EXPORT MapEntryKeyComparator : public MapKeyComparator { + public: + explicit MapEntryKeyComparator(MessageDifferencer* message_differencer); + bool IsMatch( + const Message& message1, const Message& message2, + const std::vector& parent_fields) const override; + + private: + MessageDifferencer* message_differencer_; + }; + + // Returns true if field1's number() is less than field2's. + static bool FieldBefore(const FieldDescriptor* field1, + const FieldDescriptor* field2); + + // Retrieve all the set fields, including extensions. + FieldDescriptorArray RetrieveFields(const Message& message, + bool base_message); + + // Combine the two lists of fields into the combined_fields output vector. + // All fields present in both lists will always be included in the combined + // list. Fields only present in one of the lists will only appear in the + // combined list if the corresponding fields_scope option is set to FULL. + FieldDescriptorArray CombineFields(const FieldDescriptorArray& fields1, + Scope fields1_scope, + const FieldDescriptorArray& fields2, + Scope fields2_scope); + + // Internal version of the Compare method which performs the actual + // comparison. The parent_fields vector is a vector containing field + // descriptors of all fields accessed to get to this comparison operation + // (i.e. if the current message is an embedded message, the parent_fields + // vector will contain the field that has this embedded message). + bool Compare(const Message& message1, const Message& message2, + std::vector* parent_fields); + + // Compares all the unknown fields in two messages. + bool CompareUnknownFields(const Message& message1, const Message& message2, + const UnknownFieldSet&, const UnknownFieldSet&, + std::vector* parent_fields); + + // Compares the specified messages for the requested field lists. The field + // lists are modified depending on comparison settings, and then passed to + // CompareWithFieldsInternal. + bool CompareRequestedFieldsUsingSettings( + const Message& message1, const Message& message2, + const FieldDescriptorArray& message1_fields, + const FieldDescriptorArray& message2_fields, + std::vector* parent_fields); + + // Compares the specified messages with the specified field lists. + bool CompareWithFieldsInternal(const Message& message1, + const Message& message2, + const FieldDescriptorArray& message1_fields, + const FieldDescriptorArray& message2_fields, + std::vector* parent_fields); + + // Compares the repeated fields, and report the error. + bool CompareRepeatedField(const Message& message1, const Message& message2, + const FieldDescriptor* field, + std::vector* parent_fields); + + // Shorthand for CompareFieldValueUsingParentFields with NULL parent_fields. + bool CompareFieldValue(const Message& message1, const Message& message2, + const FieldDescriptor* field, int index1, int index2); + + // Compares the specified field on the two messages, returning + // true if they are the same, false otherwise. For repeated fields, + // this method only compares the value in the specified index. This method + // uses Compare functions to recurse into submessages. + // The parent_fields vector is used in calls to a Reporter instance calls. + // It can be NULL, in which case the MessageDifferencer will create new + // list of parent messages if it needs to recursively compare the given field. + // To avoid confusing users you should not set it to NULL unless you modified + // Reporter to handle the change of parent_fields correctly. + bool CompareFieldValueUsingParentFields( + const Message& message1, const Message& message2, + const FieldDescriptor* field, int index1, int index2, + std::vector* parent_fields); + + // Compares the specified field on the two messages, returning comparison + // result, as returned by appropriate FieldComparator. + FieldComparator::ComparisonResult GetFieldComparisonResult( + const Message& message1, const Message& message2, + const FieldDescriptor* field, int index1, int index2, + const FieldContext* field_context); + + // Check if the two elements in the repeated field are match to each other. + // if the key_comprator is NULL, this function returns true when the two + // elements are equal. + bool IsMatch(const FieldDescriptor* repeated_field, + const MapKeyComparator* key_comparator, const Message* message1, + const Message* message2, + const std::vector& parent_fields, + Reporter* reporter, int index1, int index2); + + // Returns true when this repeated field has been configured to be treated + // as a Set / SmartSet / SmartList. + bool IsTreatedAsSet(const FieldDescriptor* field); + bool IsTreatedAsSmartSet(const FieldDescriptor* field); + + bool IsTreatedAsSmartList(const FieldDescriptor* field); + // When treating as SMART_LIST, it uses MatchIndicesPostProcessorForSmartList + // by default to find the longest matching sequence from the first matching + // element. The callback takes two vectors showing the matching indices from + // the other vector, where -1 means an unmatch. + void SetMatchIndicesForSmartListCallback( + std::function*, std::vector*)> callback); + + // Returns true when this repeated field is to be compared as a subset, ie. + // has been configured to be treated as a set or map and scope is set to + // PARTIAL. + bool IsTreatedAsSubset(const FieldDescriptor* field); + + // Returns true if this field is to be ignored when this + // MessageDifferencer compares messages. + bool IsIgnored(const Message& message1, const Message& message2, + const FieldDescriptor* field, + const std::vector& parent_fields); + + // Returns true if this unknown field is to be ignored when this + // MessageDifferencer compares messages. + bool IsUnknownFieldIgnored(const Message& message1, const Message& message2, + const SpecificField& field, + const std::vector& parent_fields); + + // Returns MapKeyComparator* when this field has been configured to be treated + // as a map or its is_map() return true. If not, returns NULL. + const MapKeyComparator* GetMapKeyComparator( + const FieldDescriptor* field) const; + + // Attempts to match indices of a repeated field, so that the contained values + // match. Clears output vectors and sets their values to indices of paired + // messages, ie. if message1[0] matches message2[1], then match_list1[0] == 1 + // and match_list2[1] == 0. The unmatched indices are indicated by -1. + // Assumes the repeated field is not treated as a simple list. + // This method returns false if the match failed. However, it doesn't mean + // that the comparison succeeds when this method returns true (you need to + // double-check in this case). + bool MatchRepeatedFieldIndices( + const Message& message1, const Message& message2, + const FieldDescriptor* repeated_field, + const MapKeyComparator* key_comparator, + const std::vector& parent_fields, + std::vector* match_list1, std::vector* match_list2); + + // If "any" is of type google.protobuf.Any, extract its payload using + // DynamicMessageFactory and store in "data". + bool UnpackAny(const Message& any, std::unique_ptr* data); + + // Checks if index is equal to new_index in all the specific fields. + static bool CheckPathChanged(const std::vector& parent_fields); + + // CHECKs that the given repeated field can be compared according to + // new_comparison. + void CheckRepeatedFieldComparisons( + const FieldDescriptor* field, + const RepeatedFieldComparison& new_comparison); + + // Defines a map between field descriptors and their MapKeyComparators. + // Used for repeated fields when they are configured as TreatAsMap. + typedef std::map + FieldKeyComparatorMap; + + // Defines a set to store field descriptors. Used for repeated fields when + // they are configured as TreatAsSet. + typedef std::set FieldSet; + typedef std::map FieldMap; + + Reporter* reporter_; + DefaultFieldComparator default_field_comparator_; + FieldComparator* field_comparator_; + MessageFieldComparison message_field_comparison_; + Scope scope_; + RepeatedFieldComparison repeated_field_comparison_; + + FieldMap repeated_field_comparisons_; + // Keeps track of MapKeyComparators that are created within + // MessageDifferencer. These MapKeyComparators should be deleted + // before MessageDifferencer is destroyed. + // When TreatAsMap or TreatAsMapWithMultipleFieldsAsKey is called, we don't + // store the supplied FieldDescriptors directly. Instead, a new + // MapKeyComparator is created for comparison purpose. + std::vector owned_key_comparators_; + FieldKeyComparatorMap map_field_key_comparator_; + MapEntryKeyComparator map_entry_key_comparator_; + std::vector ignore_criteria_; + // Reused multiple times in RetrieveFields to avoid extra allocations + std::vector tmp_message_fields_; + + FieldSet ignored_fields_; + + bool report_matches_; + bool report_moves_; + bool report_ignores_; + + std::string* output_string_; + + // Callback to post-process the matched indices to support SMART_LIST. + std::function*, std::vector*)> + match_indices_for_smart_list_callback_; + + std::unique_ptr dynamic_message_factory_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MessageDifferencer); +}; + +// This class provides extra information to the FieldComparator::Compare +// function. +class PROTOBUF_EXPORT FieldContext { + public: + explicit FieldContext( + std::vector* parent_fields) + : parent_fields_(parent_fields) {} + + std::vector* parent_fields() const { + return parent_fields_; + } + + private: + std::vector* parent_fields_; +}; + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_MESSAGE_DIFFERENCER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/time_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/time_util.h new file mode 100644 index 0000000000000000000000000000000000000000..5c7d4d521ee785eaedfb29d1b20165355b0b2f2e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/time_util.h @@ -0,0 +1,317 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Defines utilities for the Timestamp and Duration well known types. + +#ifndef GOOGLE_PROTOBUF_UTIL_TIME_UTIL_H__ +#define GOOGLE_PROTOBUF_UTIL_TIME_UTIL_H__ + +#include +#include +#include +#ifdef _MSC_VER +#ifdef _XBOX_ONE +struct timeval { + int64 tv_sec; /* seconds */ + int64 tv_usec; /* and microseconds */ +}; +#else +#include +#endif // _XBOX_ONE +#else +#include +#endif + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace util { + +// Utility functions for Timestamp and Duration. +class PROTOBUF_EXPORT TimeUtil { + typedef google::protobuf::Timestamp Timestamp; + typedef google::protobuf::Duration Duration; + + public: + // The min/max Timestamp/Duration values we support. + // + // For "0001-01-01T00:00:00Z". + static const int64 kTimestampMinSeconds = -62135596800LL; + // For "9999-12-31T23:59:59.999999999Z". + static const int64 kTimestampMaxSeconds = 253402300799LL; + static const int64 kDurationMinSeconds = -315576000000LL; + static const int64 kDurationMaxSeconds = 315576000000LL; + + // Converts Timestamp to/from RFC 3339 date string format. + // Generated output will always be Z-normalized and uses 3, 6 or 9 + // fractional digits as required to represent the exact time. When + // parsing, any fractional digits (or none) and any offset are + // accepted as long as they fit into nano-seconds precision. + // Note that Timestamp can only represent time from + // 0001-01-01T00:00:00Z to 9999-12-31T23:59:59.999999999Z. Converting + // a Timestamp outside of this range is undefined behavior. + // See https://www.ietf.org/rfc/rfc3339.txt + // + // Example of generated format: + // "1972-01-01T10:00:20.021Z" + // + // Example of accepted format: + // "1972-01-01T10:00:20.021-05:00" + static std::string ToString(const Timestamp& timestamp); + static bool FromString(const std::string& value, Timestamp* timestamp); + + // Converts Duration to/from string format. The string format will contains + // 3, 6, or 9 fractional digits depending on the precision required to + // represent the exact Duration value. For example: + // "1s", "1.010s", "1.000000100s", "-3.100s" + // The range that can be represented by Duration is from -315,576,000,000 + // to +315,576,000,000 inclusive (in seconds). + static std::string ToString(const Duration& duration); + static bool FromString(const std::string& value, Duration* timestamp); + +#ifdef GetCurrentTime +#undef GetCurrentTime // Visual Studio has macro GetCurrentTime +#endif + // Gets the current UTC time. + static Timestamp GetCurrentTime(); + // Returns the Time representing "1970-01-01 00:00:00". + static Timestamp GetEpoch(); + + // Converts between Duration and integer types. The behavior is undefined if + // the input value is not in the valid range of Duration. + static Duration NanosecondsToDuration(int64 nanos); + static Duration MicrosecondsToDuration(int64 micros); + static Duration MillisecondsToDuration(int64 millis); + static Duration SecondsToDuration(int64 seconds); + static Duration MinutesToDuration(int64 minutes); + static Duration HoursToDuration(int64 hours); + // Result will be truncated towards zero. For example, "-1.5s" will be + // truncated to "-1s", and "1.5s" to "1s" when converting to seconds. + // It's undefined behavior if the input duration is not valid or the result + // exceeds the range of int64. A duration is not valid if it's not in the + // valid range of Duration, or have an invalid nanos value (i.e., larger + // than 999999999, less than -999999999, or have a different sign from the + // seconds part). + static int64 DurationToNanoseconds(const Duration& duration); + static int64 DurationToMicroseconds(const Duration& duration); + static int64 DurationToMilliseconds(const Duration& duration); + static int64 DurationToSeconds(const Duration& duration); + static int64 DurationToMinutes(const Duration& duration); + static int64 DurationToHours(const Duration& duration); + // Creates Timestamp from integer types. The integer value indicates the + // time elapsed from Epoch time. The behavior is undefined if the input + // value is not in the valid range of Timestamp. + static Timestamp NanosecondsToTimestamp(int64 nanos); + static Timestamp MicrosecondsToTimestamp(int64 micros); + static Timestamp MillisecondsToTimestamp(int64 millis); + static Timestamp SecondsToTimestamp(int64 seconds); + // Result will be truncated down to the nearest integer value. For example, + // with "1969-12-31T23:59:59.9Z", TimestampToMilliseconds() returns -100 + // and TimestampToSeconds() returns -1. It's undefined behavior if the input + // Timestamp is not valid (i.e., its seconds part or nanos part does not fall + // in the valid range) or the return value doesn't fit into int64. + static int64 TimestampToNanoseconds(const Timestamp& timestamp); + static int64 TimestampToMicroseconds(const Timestamp& timestamp); + static int64 TimestampToMilliseconds(const Timestamp& timestamp); + static int64 TimestampToSeconds(const Timestamp& timestamp); + + // Conversion to/from other time/date types. Note that these types may + // have a different precision and time range from Timestamp/Duration. + // When converting to a lower precision type, the value will be truncated + // to the nearest value that can be represented. If the value is + // out of the range of the result type, the return value is undefined. + // + // Conversion to/from time_t + static Timestamp TimeTToTimestamp(time_t value); + static time_t TimestampToTimeT(const Timestamp& value); + + // Conversion to/from timeval + static Timestamp TimevalToTimestamp(const timeval& value); + static timeval TimestampToTimeval(const Timestamp& value); + static Duration TimevalToDuration(const timeval& value); + static timeval DurationToTimeval(const Duration& value); +}; + +} // namespace util +} // namespace protobuf +} // namespace google + +namespace google { +namespace protobuf { +// Overloaded operators for Duration. +// +// Assignment operators. +PROTOBUF_EXPORT Duration& operator+=(Duration& d1, + const Duration& d2); // NOLINT +PROTOBUF_EXPORT Duration& operator-=(Duration& d1, + const Duration& d2); // NOLINT +PROTOBUF_EXPORT Duration& operator*=(Duration& d, int64 r); // NOLINT +PROTOBUF_EXPORT Duration& operator*=(Duration& d, double r); // NOLINT +PROTOBUF_EXPORT Duration& operator/=(Duration& d, int64 r); // NOLINT +PROTOBUF_EXPORT Duration& operator/=(Duration& d, double r); // NOLINT +// Overload for other integer types. +template +Duration& operator*=(Duration& d, T r) { // NOLINT + int64 x = r; + return d *= x; +} +template +Duration& operator/=(Duration& d, T r) { // NOLINT + int64 x = r; + return d /= x; +} +PROTOBUF_EXPORT Duration& operator%=(Duration& d1, + const Duration& d2); // NOLINT +// Relational operators. +inline bool operator<(const Duration& d1, const Duration& d2) { + if (d1.seconds() == d2.seconds()) { + return d1.nanos() < d2.nanos(); + } + return d1.seconds() < d2.seconds(); +} +inline bool operator>(const Duration& d1, const Duration& d2) { + return d2 < d1; +} +inline bool operator>=(const Duration& d1, const Duration& d2) { + return !(d1 < d2); +} +inline bool operator<=(const Duration& d1, const Duration& d2) { + return !(d2 < d1); +} +inline bool operator==(const Duration& d1, const Duration& d2) { + return d1.seconds() == d2.seconds() && d1.nanos() == d2.nanos(); +} +inline bool operator!=(const Duration& d1, const Duration& d2) { + return !(d1 == d2); +} +// Additive operators +inline Duration operator-(const Duration& d) { + Duration result; + result.set_seconds(-d.seconds()); + result.set_nanos(-d.nanos()); + return result; +} +inline Duration operator+(const Duration& d1, const Duration& d2) { + Duration result = d1; + return result += d2; +} +inline Duration operator-(const Duration& d1, const Duration& d2) { + Duration result = d1; + return result -= d2; +} +// Multiplicative operators +template +inline Duration operator*(Duration d, T r) { + return d *= r; +} +template +inline Duration operator*(T r, Duration d) { + return d *= r; +} +template +inline Duration operator/(Duration d, T r) { + return d /= r; +} +PROTOBUF_EXPORT int64 operator/(const Duration& d1, const Duration& d2); + +inline Duration operator%(const Duration& d1, const Duration& d2) { + Duration result = d1; + return result %= d2; +} + +inline std::ostream& operator<<(std::ostream& out, const Duration& d) { + out << ::PROTOBUF_NAMESPACE_ID::util::TimeUtil::ToString(d); + return out; +} + +// Overloaded operators for Timestamp +// +// Assignment operators. +PROTOBUF_EXPORT Timestamp& operator+=(Timestamp& t, + const Duration& d); // NOLINT +PROTOBUF_EXPORT Timestamp& operator-=(Timestamp& t, + const Duration& d); // NOLINT +// Relational operators. +inline bool operator<(const Timestamp& t1, const Timestamp& t2) { + if (t1.seconds() == t2.seconds()) { + return t1.nanos() < t2.nanos(); + } + return t1.seconds() < t2.seconds(); +} +inline bool operator>(const Timestamp& t1, const Timestamp& t2) { + return t2 < t1; +} +inline bool operator>=(const Timestamp& t1, const Timestamp& t2) { + return !(t1 < t2); +} +inline bool operator<=(const Timestamp& t1, const Timestamp& t2) { + return !(t2 < t1); +} +inline bool operator==(const Timestamp& t1, const Timestamp& t2) { + return t1.seconds() == t2.seconds() && t1.nanos() == t2.nanos(); +} +inline bool operator!=(const Timestamp& t1, const Timestamp& t2) { + return !(t1 == t2); +} +// Additive operators. +inline Timestamp operator+(const Timestamp& t, const Duration& d) { + Timestamp result = t; + return result += d; +} +inline Timestamp operator+(const Duration& d, const Timestamp& t) { + Timestamp result = t; + return result += d; +} +inline Timestamp operator-(const Timestamp& t, const Duration& d) { + Timestamp result = t; + return result -= d; +} +PROTOBUF_EXPORT Duration operator-(const Timestamp& t1, const Timestamp& t2); + +inline std::ostream& operator<<(std::ostream& out, const Timestamp& t) { + out << ::PROTOBUF_NAMESPACE_ID::util::TimeUtil::ToString(t); + return out; +} + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_TIME_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..70f14b58a8776e0b70936176edcb6e2ac54665cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Defines a TypeResolver for the Any message. + +#ifndef GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_H__ +#define GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_H__ + +#include + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +class DescriptorPool; +namespace util { + +// Abstract interface for a type resolver. +// +// Implementations of this interface must be thread-safe. +class PROTOBUF_EXPORT TypeResolver { + public: + TypeResolver() {} + virtual ~TypeResolver() {} + + // Resolves a type url for a message type. + virtual util::Status ResolveMessageType( + const std::string& type_url, google::protobuf::Type* message_type) = 0; + + // Resolves a type url for an enum type. + virtual util::Status ResolveEnumType(const std::string& type_url, + google::protobuf::Enum* enum_type) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TypeResolver); +}; + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver_util.h new file mode 100644 index 0000000000000000000000000000000000000000..207a84cce41e8b8182a836e797150b8641006ef0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/util/type_resolver_util.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Defines utilities for the TypeResolver. + +#ifndef GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_UTIL_H__ +#define GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_UTIL_H__ + +#include + +namespace google { +namespace protobuf { +class DescriptorPool; +namespace util { +class TypeResolver; + +#include + +// Creates a TypeResolver that serves type information in the given descriptor +// pool. Caller takes ownership of the returned TypeResolver. +PROTOBUF_EXPORT TypeResolver* NewTypeResolverForDescriptorPool( + const std::string& url_prefix, const DescriptorPool* pool); + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_UTIL_TYPE_RESOLVER_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab38a732799ece06ac286f281819957b9b1aa640 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..907eff0e7ed59228cf41b8ca667f527c91729911 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beb95a2cfa8e827aba203e2b70161b1d9bb7b4d0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103688e4fdc47f8f1223ee38d28efd3732201281 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60630ad41af938c97d23cfdbbbd0b4b41d763f7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3f08f8f2170d786799c22f5d565711b4f621451 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3e22ed2d1a60114c7c344cbcf91b7f6d98f41b1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d8a3267b4783723d0e4711d48c7d6c335b646c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3b73453c593b19bff4fdc18d5f9cd961a54fd6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/backends/__pycache__/thnn.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1727008e8c7be30f901e673f05e553e80deecf59 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/parallel/__pycache__/comm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97042a9670300032a22b8285552fa9c74ebb872c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..477d8f955340998c8698cae98bdf41cdd6a693d0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275f6aef4c5df86fe4e6f14834559bb07403ea60 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94a1f2975a514389d1cc5aa67f6390dabddbea1e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b841e48e6a88793f66685c72b51dd84cdce46071 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a94975d5283bb5ffd0a986d29a9df3447e1981 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454f7b1dd7deab03f49b053951aa6c150a916f14 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..581a4f1c504e9108dcc842011d9f08c6665c5dd7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a8918c5b0bfb1587bf05151a1a45d596181b702 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28e3b74efedad4f27703a7dc5c8703dfca11f6d6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6affcab86bc27e12c4463decc033bc431f902772 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b1bb1918597c7092091727d1934f0cdb30bc55e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05717d795cebca9f0672130807a5a57cd2616541 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35803762e12ecddf426be43eb66bbba4587ec7a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9835778939aca81e6432c6d84b26aa5e7b07e0ca Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6209a1fd6036ba6d371a497760465d8920e9b2a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af5370ac8592a4f1562603fb2ccca9b4a47e2351 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c383d5846059df0ab7d44a8bca04146322d5b03 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2001033d7457f09996ffe50987f1e89351ef827d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2da57f7e7cb624187221c3904e0cb3ad5749e26b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af436859022dd92f265da3541e09836deddbac7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc03968eb10e9051790f3fa504eb62638652e1d2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d58b03e1a5f101f6649a4a7aaa7a17e4b812c54 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81f7a86fd359169cd9e3b2c3d1fc6d786f2492e9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59358dfb8bc7b24e454e6d8245113724adcc3208 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efc55c0544cec84ba1b05956dcea928eae25d58c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16906dd35c9bb0721357a70b77c7418ff75b9dfd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac484c3c1a5cee15205c16eb7d9769cb8b9f142 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e655b8a72999088bbf759f407ed750d918465c5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6db9e7db2361775e1394ce5f866817000bef003 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d6f8539448ae76d62eb8197eab3d187303726a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81ec52174e3f0a4409294d4648584564882f4fc1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a8bb61caf7bf32d7bbe216443e5e1734d1ef987 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef1157809c18e0db987bf3652ee84efa007a2735 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dba883aa171517aa3cd36b2a30f861055ef007c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44163ca02de46d115a0ee81bd5ffd60f00a8f4cf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py new file mode 100644 index 0000000000000000000000000000000000000000..46abb4bb758dde5752d974f5459ccd77ac9c0f74 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py @@ -0,0 +1,633 @@ +# mypy: ignore-errors + +import torch +from functools import partial +from torch.testing import make_tensor +from torch.testing._internal.opinfo.core import ( + OpInfo, + SampleInput, +) +from torch.testing._internal.common_dtype import all_types_and +import numpy as np + +# Note: [autograd.Function db] +# +# This is a collection of autograd.Function test cases written as OpInfos +# so they can easily be consumed by OpInfo-based tests to check if a subsystem +# supports autograd.Function. +# +# Axes: +# - saves {output, input, intermediate, non-tensor} +# - {inputs, output} x {single tensor, tensors, arbitrary objects} +# - Uses {mark_dirty, mark_non_differentiable, once_differentiable} + + +def to_numpy(tensor): + return tensor.cpu().numpy() + + +class NumpyCube(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + dinput = torch.tensor(3 * input_np ** 2, device=input.device) + return torch.tensor(input_np ** 3, device=input.device), dinput + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0], output[1]) + ctx.save_for_forward(inputs[0], output[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input) + + @staticmethod + def vmap(info, in_dims, input): + result = NumpyCube.apply(input) + return result, (in_dims[0], in_dims[0]) + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +class CubeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x): + return x ** 3, 3 * x ** 2 + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(inputs[0], outputs[1]) + ctx.save_for_forward(inputs[0], outputs[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + _input, dinput = ctx.saved_tensors + result = grad_output * dinput + 6 * dinput + return result + + @staticmethod + def jvp(ctx, input_tangent): + input, dinput = ctx.saved_tensors + return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input) + + +def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(1, low=0.8, high=2), args=()) + + +class NumpyCubeNotComposable(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) + return torch.tensor(input_np ** 3, device=input.device), input_np + + @staticmethod + def setup_context(ctx, inputs, output): + _, input_np = output + ctx.input_np = input_np + ctx.device = inputs[0].device + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output, grad_saved): + result_np = 3 * (ctx.input_np ** 2) + return torch.tensor(result_np, device=ctx.device) + + +class NumpyMul(torch.autograd.Function): + @staticmethod + def forward(x, y): + return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = NumpyMul.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = NumpyMul.apply(grad_output, x) + return gx, gy + + @staticmethod + def vmap(info, in_dims, x, y): + x_bdim, y_bdim = in_dims + x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + result = NumpyMul.apply(x, y) + result = result.movedim(-1, 0) + return result, 0 + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + +def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Broadcasting + yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),)) + +def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14}) + +class MulGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x * y + + @staticmethod + def setup_context(ctx, inputs, outputs): + ctx.save_for_backward(*inputs) + ctx.save_for_forward(*inputs) + + @staticmethod + def backward(ctx, grad_output): + x, y = ctx.saved_tensors + gx = None + if ctx.needs_input_grad[0]: + gx = MulGenVmap.apply(grad_output, y) + gy = None + if ctx.needs_input_grad[1]: + gy = MulGenVmap.apply(grad_output, x) + return gx, gy + + @staticmethod + def jvp(ctx, x_tangent, y_tangent): + x, y = ctx.saved_tensors + return x_tangent * y + y_tangent * x + + +class NumpyExp_(torch.autograd.Function): + @staticmethod + def forward(x): + x_np = to_numpy(x) + np.exp(x_np, x_np) + return x + + @staticmethod + def setup_context(ctx, inputs, output): + x, = inputs + ctx.mark_dirty(x) + ctx.save_for_backward(output) + ctx.save_for_forward(output) + + @staticmethod + def backward(ctx, grad_output): + output, = ctx.saved_tensors + return NumpyMul.apply(grad_output, output) + + @staticmethod + def vmap(info, in_dims, x): + NumpyExp_.apply(x) + return x, in_dims[0] + + @staticmethod + def jvp(ctx, x_tangent): + # Doesn't call numpy operations because I didn't want to write NumpyMul_ + output, = ctx.saved_tensors + x_tangent.mul_(output) + return x_tangent + +class NumpySort(torch.autograd.Function): + @staticmethod + def forward(x, dim): + device = x.device + x = to_numpy(x) + ind = np.argsort(x, axis=dim) + ind_inv = np.argsort(ind, axis=dim) + return ( + torch.tensor(x, device=device), + torch.tensor(ind, device=device), + torch.tensor(ind_inv, device=device), + ) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, dim = inputs + _, ind, ind_inv = output + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def vmap(info, in_dims, x, dim): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 0) + # wrap dim + dim = dim if dim >= 0 else dim + x.dim() - 1 + return NumpySort.apply(x, dim + 1), (0, 0, 0) + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + +class SortGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, dim): + ind = torch.argsort(x, dim=dim) + ind_inv = torch.argsort(ind, axis=dim) + result = torch.take_along_dim(x, ind, dim=dim) + return result, ind, ind_inv + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, dim = inputs + _, ind, ind_inv = outputs + ctx.mark_non_differentiable(ind, ind_inv) + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output, _0, _1): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None + + @staticmethod + def jvp(ctx, x_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None + + +def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(1,)) + + +def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor = make_arg(3, 5) + dim = 1 + _, ind, ind_inv = NumpySort.apply(tensor, 1) + yield SampleInput(tensor, args=(ind, ind_inv, dim)) + + +class NumpyTake(torch.autograd.Function): + @staticmethod + def forward(x, ind, ind_inv, dim): + device = x.device + x = to_numpy(x) + ind = to_numpy(ind) + return torch.tensor(np.take_along_axis(x, ind, dim), device=device) + + @staticmethod + def setup_context(ctx, inputs, output): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def vmap(info, in_dims, x, ind, ind_inv, dim): + x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims + + # wrap dim + logical_dim = x.dim() if x_bdim is None else x_bdim - 1 + dim = dim if dim >= 0 else dim + logical_dim + + def expand_bdim(x, x_bdim): + if x_bdim is None: + return x.expand(info.batch_size, *x.shape) + return x.movedim(x_bdim, 0) + + x = expand_bdim(x, x_bdim) + ind = expand_bdim(ind, ind_bdim) + ind_inv = expand_bdim(ind_inv, ind_inv_bdim) + + return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0 + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + assert ind_tangent is None + assert ind_inv_tangent is None + ind, ind_inv = ctx.saved_tensors + return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim) + +class TakeGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, ind, ind_inv, dim): + return torch.take_along_dim(x, ind, dim) + + @staticmethod + def setup_context(ctx, inputs, outputs): + _x, ind, ind_inv, dim = inputs + ctx.save_for_backward(ind, ind_inv) + ctx.save_for_forward(ind, ind_inv) + ctx.dim = dim + + @staticmethod + def backward(ctx, grad_output): + ind, ind_inv = ctx.saved_tensors + result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim) + return result, None, None, None + + @staticmethod + def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _): + ind, ind_inv = ctx.saved_tensors + return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim) + +class Select(torch.autograd.Function): + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return Select.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return Select.apply(x_tangent, ctx.idx) + +class SelectGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, idx): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def jvp(ctx, x_tangent, _): + return SelectGenVmap.apply(x_tangent, ctx.idx) + + +def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5), args=(2,)) + +class ScaleGradGenVmap(torch.autograd.Function): + generate_vmap_rule = True + scale = 3.14 + + @staticmethod + def forward(x): + return x.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, grad_output): + return grad_output * ScaleGradGenVmap.scale + + @staticmethod + def jvp(ctx, x_tangent): + return x_tangent * ScaleGradGenVmap.scale + +class ZeroGradientsGenVmap(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x, y): + return x.clone(), y.clone() + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def backward(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + # Intentionally too-large gradient + torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + @staticmethod + def jvp(ctx, gx, gy): + # Intentionally returning torch.zeros instead of zeros_like or new_zeros. + # Also intentionally not None. + return ( + torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device), + torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device), + ) + + +def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(3, 5)) + + +class ForwardHasDefaultArgs(torch.autograd.Function): + @staticmethod + def forward(x, idx=(2,)): + return x[idx] + + @staticmethod + def setup_context(ctx, inputs, output): + x, idx = inputs + ctx.x_shape = x.shape + ctx.idx = idx + + @staticmethod + def backward(ctx, grad_output): + result = grad_output.new_zeros(ctx.x_shape) + result[ctx.idx] = grad_output + return result, None + + @staticmethod + def vmap(info, in_dims, x, idx): + x_bdim, _ = in_dims + x = x.movedim(x_bdim, 1) + return ForwardHasDefaultArgs.apply(x, idx), 0 + + @staticmethod + def jvp(ctx, x_tangent, _): + return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx) + + +autograd_function_db = [ + OpInfo( + 'NumpyCubeAutogradFunction', + op=NumpyCube.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyExpMarkDirtyAutogradFunction', + op=lambda x: NumpyExp_.apply(x.clone()), + inplace_variant=NumpyExp_.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyMulAutogradFunction', + op=NumpyMul.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpyCubeNotComposableAutogradFunction', + op=lambda x: NumpyCubeNotComposable.apply(x)[0], + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'NumpySortAutogradFunction', + op=NumpySort.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'NumpyTakeAutogradFunction', + op=NumpyTake.apply, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_numpy_take, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SelectAutogradFunction', + op=Select.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'CubeGenVmapAutogradFunction', + op=CubeGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'MulGenVmapAutogradFunction', + op=MulGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'SortGenVmapAutogradFunction', + op=SortGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_sort, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + gradcheck_wrapper=lambda y, ind: y, + ), + OpInfo( + 'SelectGenVmapAutogradFunction', + op=SelectGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_select, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ScaleGradGenVmapAutogradFunction', + op=ScaleGradGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_cube, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ZeroGradientsGenVmapAutogradFunction', + op=ZeroGradientsGenVmap.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_numpy_mul, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), + OpInfo( + 'ForwardHasDefaultArgsAutogradFunction', + op=ForwardHasDefaultArgs.apply, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_forward_default_args, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + ), +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9acc6f0f7567627c30411ed4ddf61ba2022418bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py @@ -0,0 +1,2038 @@ +# mypy: ignore-errors + +import copy +import gc +import inspect +import os +import runpy +import sys +import threading +import unittest +from collections import namedtuple +from collections.abc import Callable, Iterable, Sequence +from enum import Enum +from functools import partial, wraps +from typing import Any, ClassVar, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch._inductor.utils import GPU_TYPES +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + _get_torch_rocm_version, + TEST_CUSPARSE_GENERIC, + TEST_HIPSPARSE_GENERIC, +) +from torch.testing._internal.common_dtype import get_all_dtypes +from torch.testing._internal.common_utils import ( + _TestParametrizer, + clear_tracked_input, + compose_parametrize_fns, + dtype_name, + get_tracked_input, + IS_FBCODE, + IS_MACOS, + is_privateuse1_backend_available, + IS_REMOTE_GPU, + IS_S390X, + IS_SANDCASTLE, + IS_WINDOWS, + NATIVE_DEVICES, + PRINT_REPRO_ON_FAILURE, + skipCUDANonDefaultStreamIf, + skipIfTorchDynamo, + TEST_HPU, + TEST_MKL, + TEST_MPS, + TEST_WITH_ASAN, + TEST_WITH_MIOPEN_SUGGEST_NHWC, + TEST_WITH_MTIA, + TEST_WITH_ROCM, + TEST_WITH_TORCHINDUCTOR, + TEST_WITH_TSAN, + TEST_WITH_UBSAN, + TEST_XPU, + TestCase, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +try: + import psutil # type: ignore[import] + + HAS_PSUTIL = True +except ModuleNotFoundError: + HAS_PSUTIL = False + psutil = None + +# Note [Writing Test Templates] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# This note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# PyTorch has its own framework for instantiating test templates. That is, for +# taking test classes that look similar to unittest or pytest +# compatible test classes and optionally doing the following: +# +# - instantiating a version of the test class for each available device type +# (often the CPU, CUDA, and META device types) +# - further instantiating a version of each test that's always specialized +# on the test class's device type, and optionally specialized further +# on datatypes or operators +# +# This functionality is similar to pytest's parametrize functionality +# (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable +# additional logic that specializes the instantiated test classes for their +# device types (see CPUTestBase and CUDATestBase below), supports a variety +# of composable decorators that allow for test filtering and setting +# tolerances, and allows tests parametrized by operators to instantiate +# only the subset of device type x dtype that operator supports. +# +# This framework was built to make it easier to write tests that run on +# multiple device types, multiple datatypes (dtypes), and for multiple +# operators. It's also useful for controlling which tests are run. For example, +# only tests that use a CUDA device can be run on platforms with CUDA. +# Let's dive in with an example to get an idea for how it works: +# +# -------------------------------------------------------- +# A template class (looks like a regular unittest TestCase) +# class TestClassFoo(TestCase): +# +# # A template test that can be specialized with a device +# # NOTE: this test case is not runnable by unittest or pytest because it +# # accepts an extra positional argument, "device", that they do not understand +# def test_bar(self, device): +# pass +# +# # Function that instantiates a template class and its tests +# instantiate_device_type_tests(TestCommon, globals()) +# -------------------------------------------------------- +# +# In the above code example we see a template class and a single test template +# that can be instantiated with a device. The function +# instantiate_device_type_tests(), called at file scope, instantiates +# new test classes, one per available device type, and new tests in those +# classes from these templates. It actually does this by removing +# the class TestClassFoo and replacing it with classes like TestClassFooCPU +# and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase +# and CUDATestBase respectively. Additional device types, like XLA, +# (see https://github.com/pytorch/xla) can further extend the set of +# instantiated test classes to create classes like TestClassFooXLA. +# +# The test template, test_bar(), is also instantiated. In this case the template +# is only specialized on a device, so (depending on the available device +# types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda() +# in TestClassFooCUDA. We can think of the instantiated test classes as +# looking like this: +# +# -------------------------------------------------------- +# # An instantiated test class for the CPU device type +# class TestClassFooCPU(CPUTestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cpu(self): +# test_bar(self, 'cpu') +# +# # An instantiated test class for the CUDA device type +# class TestClassFooCUDA(CUDATestBase): +# +# # An instantiated test that calls the template with the string representation +# # of a device from the test class's device type +# def test_bar_cuda(self): +# test_bar(self, 'cuda:0') +# -------------------------------------------------------- +# +# These instantiated test classes ARE discoverable and runnable by both +# unittest and pytest. One thing that may be confusing, however, is that +# attempting to run "test_bar" will not work, despite it appearing in the +# original template code. This is because "test_bar" is no longer discoverable +# after instantiate_device_type_tests() runs, as the above snippet shows. +# Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both +# can be run with the option "-k test_bar". +# +# Removing the template class and adding the instantiated classes requires +# passing "globals()" to instantiate_device_type_tests(), because it +# edits the file's Python objects. +# +# As mentioned, tests can be additionally parametrized on dtypes or +# operators. Datatype parametrization uses the @dtypes decorator and +# require a test template like this: +# +# -------------------------------------------------------- +# # A template test that can be specialized with a device and a datatype (dtype) +# @dtypes(torch.float32, torch.int64) +# def test_car(self, device, dtype) +# pass +# -------------------------------------------------------- +# +# If the CPU and CUDA device types are available this test would be +# instantiated as 4 tests that cover the cross-product of the two dtypes +# and two device types: +# +# - test_car_cpu_float32 +# - test_car_cpu_int64 +# - test_car_cuda_float32 +# - test_car_cuda_int64 +# +# The dtype is passed as a torch.dtype object. +# +# Tests parametrized on operators (actually on OpInfos, more on that in a +# moment...) use the @ops decorator and require a test template like this: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and OpInfo +# @ops(op_db) +# def test_car(self, device, dtype, op) +# pass +# -------------------------------------------------------- +# +# See the documentation for the @ops decorator below for additional details +# on how to use it and see the note [OpInfos] in +# common_methods_invocations.py for more details on OpInfos. +# +# A test parametrized over the entire "op_db", which contains hundreds of +# OpInfos, will likely have hundreds or thousands of instantiations. The +# test will be instantiated on the cross-product of device types, operators, +# and the dtypes the operator supports on that device type. The instantiated +# tests will have names like: +# +# - test_car_add_cpu_float32 +# - test_car_sub_cuda_int64 +# +# The first instantiated test calls the original test_car() with the OpInfo +# for torch.add as its "op" argument, the string 'cpu' for its "device" argument, +# and the dtype torch.float32 for is "dtype" argument. The second instantiated +# test calls the test_car() with the OpInfo for torch.sub, a CUDA device string +# like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype +# torch.int64 for its "dtype argument." +# +# In addition to parametrizing over device, dtype, and ops via OpInfos, the +# @parametrize decorator is supported for arbitrary parametrizations: +# -------------------------------------------------------- +# # A template test that can be specialized with a device, dtype, and value for x +# @parametrize("x", range(5)) +# def test_car(self, device, dtype, x) +# pass +# -------------------------------------------------------- +# +# See the documentation for @parametrize in common_utils.py for additional details +# on this. Note that the instantiate_device_type_tests() function will handle +# such parametrizations; there is no need to additionally call +# instantiate_parametrized_tests(). +# +# Clever test filtering can be very useful when working with parametrized +# tests. "-k test_car" would run every instantiated variant of the test_car() +# test template, and "-k test_car_add" runs every variant instantiated with +# torch.add. +# +# It is important to use the passed device and dtype as appropriate. Use +# helper functions like make_tensor() that require explicitly specifying +# the device and dtype so they're not forgotten. +# +# Test templates can use a variety of composable decorators to specify +# additional options and requirements, some are listed here: +# +# - @deviceCountAtLeast() +# Passes a list of strings representing all available devices of +# the test class's device type as the test template's "device" argument. +# If there are fewer devices than the value passed to the decorator +# the test is skipped. +# - @dtypes() +# In addition to accepting multiple dtypes, the @dtypes decorator +# can accept a sequence of tuple pairs of dtypes. The test template +# will be called with each tuple for its "dtype" argument. +# - @onlyNativeDeviceTypes +# Skips the test if the device is not a native device type (currently CPU, CUDA, Meta) +# - @onlyCPU +# Skips the test if the device is not a CPU device +# - @onlyCUDA +# Skips the test if the device is not a CUDA device +# - @onlyMPS +# Skips the test if the device is not a MPS device +# - @skipCPUIfNoLapack +# Skips the test if the device is a CPU device and LAPACK is not installed +# - @skipCPUIfNoMkl +# Skips the test if the device is a CPU device and MKL is not installed +# - @skipCUDAIfNoMagma +# Skips the test if the device is a CUDA device and MAGMA is not installed +# - @skipCUDAIfRocm +# Skips the test if the device is a CUDA device and ROCm is being used + + +# Note [Adding a Device Type] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To add a device type: +# +# (1) Create a new "TestBase" extending DeviceTypeTestBase. +# See CPUTestBase and CUDATestBase below. +# (2) Define the "device_type" attribute of the base to be the +# appropriate string. +# (3) Add logic to this file that appends your base class to +# device_type_test_bases when your device type is available. +# (4) (Optional) Write setUpClass/tearDownClass class methods that +# instantiate dependencies (see MAGMA in CUDATestBase). +# (5) (Optional) Override the "instantiate_test" method for total +# control over how your class creates tests. +# +# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF +# they are run. This makes it useful for initializing devices and dependencies. + + +def _dtype_test_suffix(dtypes): + """Returns the test suffix for a dtype, sequence of dtypes, or None.""" + if isinstance(dtypes, (list, tuple)): + if len(dtypes) == 0: + return "" + return "_" + "_".join(dtype_name(d) for d in dtypes) + elif dtypes: + return f"_{dtype_name(dtypes)}" + else: + return "" + + +def _update_param_kwargs(param_kwargs, name, value): + """Adds a kwarg with the specified name and value to the param_kwargs dict.""" + # Make name plural (e.g. devices / dtypes) if the value is composite. + plural_name = f"{name}s" + + # Clear out old entries of the arg if any. + if name in param_kwargs: + del param_kwargs[name] + if plural_name in param_kwargs: + del param_kwargs[plural_name] + + if isinstance(value, (list, tuple)): + param_kwargs[plural_name] = value + elif value is not None: + param_kwargs[name] = value + + # Leave param_kwargs as-is when value is None. + + +class DeviceTypeTestBase(TestCase): + device_type: str = "generic_device_type" + + # Flag to disable test suite early due to unrecoverable error such as CUDA error. + _stop_test_suite = False + + # Precision is a thread-local setting since it may be overridden per test + _tls = threading.local() + _tls.precision = TestCase._precision + _tls.rel_tol = TestCase._rel_tol + + @property + def precision(self): + return self._tls.precision + + @precision.setter + def precision(self, prec): + self._tls.precision = prec + + @property + def rel_tol(self): + return self._tls.rel_tol + + @rel_tol.setter + def rel_tol(self, prec): + self._tls.rel_tol = prec + + # Returns a string representing the device that single device tests should use. + # Note: single device tests use this device exclusively. + @classmethod + def get_primary_device(cls): + return cls.device_type + + @classmethod + def _init_and_get_primary_device(cls): + try: + return cls.get_primary_device() + except Exception: + # For CUDATestBase, XPUTestBase, XLATestBase, and possibly others, the primary device won't be available + # until setUpClass() sets it. Call that manually here if needed. + if hasattr(cls, "setUpClass"): + cls.setUpClass() + return cls.get_primary_device() + + # Returns a list of strings representing all available devices of this + # device type. The primary device must be the first string in the list + # and the list must contain no duplicates. + # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic + # mechanism of acquiring all available devices. + @classmethod + def get_all_devices(cls): + return [cls.get_primary_device()] + + # Returns the dtypes the test has requested. + # Prefers device-specific dtype specifications over generic ones. + @classmethod + def _get_dtypes(cls, test): + if not hasattr(test, "dtypes"): + return None + + default_dtypes = test.dtypes.get("all") + msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it" + assert default_dtypes is not None, msg + + return test.dtypes.get(cls.device_type, default_dtypes) + + def _get_precision_override(self, test, dtype): + if not hasattr(test, "precision_overrides"): + return self.precision + return test.precision_overrides.get(dtype, self.precision) + + def _get_tolerance_override(self, test, dtype): + if not hasattr(test, "tolerance_overrides"): + return self.precision, self.rel_tol + return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol)) + + def _apply_precision_override_for_test(self, test, param_kwargs): + dtype = param_kwargs.get("dtype") + dtype = param_kwargs.get("dtypes", dtype) + if dtype: + self.precision = self._get_precision_override(test, dtype) + self.precision, self.rel_tol = self._get_tolerance_override(test, dtype) + + # Creates device-specific tests. + @classmethod + def instantiate_test(cls, name, test, *, generic_cls=None): + def instantiate_test_helper( + cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: [] + ): + # Add the device param kwarg if the test needs device or devices. + param_kwargs = {} if param_kwargs is None else param_kwargs + test_sig_params = inspect.signature(test).parameters + if "device" in test_sig_params or "devices" in test_sig_params: + device_arg: str = cls._init_and_get_primary_device() + if hasattr(test, "num_required_devices"): + device_arg = cls.get_all_devices() + _update_param_kwargs(param_kwargs, "device", device_arg) + + # Apply decorators based on param kwargs. + for decorator in decorator_fn(param_kwargs): + test = decorator(test) + + # Constructs the test + @wraps(test) + def instantiated_test(self, param_kwargs=param_kwargs): + # Sets precision and runs test + # Note: precision is reset after the test is run + guard_precision = self.precision + guard_rel_tol = self.rel_tol + try: + self._apply_precision_override_for_test(test, param_kwargs) + result = test(self, **param_kwargs) + except RuntimeError as rte: + # check if rte should stop entire test suite. + self._stop_test_suite = self._should_stop_test_suite() + # Check if test has been decorated with `@expectedFailure` + # Using `__unittest_expecting_failure__` attribute, see + # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164 + # In that case, make it fail with "unexpected success" by suppressing exception + if ( + getattr(test, "__unittest_expecting_failure__", False) + and self._stop_test_suite + ): + import sys + + print( + "Suppressing fatal exception to trigger unexpected success", + file=sys.stderr, + ) + return + # raise the runtime error as is for the test suite to record. + raise rte + finally: + self.precision = guard_precision + self.rel_tol = guard_rel_tol + + return result + + assert not hasattr(cls, name), f"Redefinition of test {name}" + setattr(cls, name, instantiated_test) + + def default_parametrize_fn(test, generic_cls, device_cls): + # By default, no parametrization is needed. + yield (test, "", {}, lambda _: []) + + # Parametrization decorators set the parametrize_fn attribute on the test. + parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn) + + # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it. + dtypes = cls._get_dtypes(test) + if dtypes is not None: + + def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes): + for dtype in dtypes: + param_kwargs: dict[str, Any] = {} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # Note that an empty test suffix is set here so that the dtype can be appended + # later after the device. + yield (test, "", param_kwargs, lambda _: []) + + parametrize_fn = compose_parametrize_fns( + dtype_parametrize_fn, parametrize_fn + ) + + # Instantiate the parametrized tests. + for ( + test, # noqa: B020 + test_suffix, + param_kwargs, + decorator_fn, + ) in parametrize_fn(test, generic_cls, cls): + test_suffix = "" if test_suffix == "" else "_" + test_suffix + cls_device_type = ( + cls.device_type + if cls.device_type != "privateuse1" + else torch._C._get_privateuse1_backend_name() + ) + device_suffix = "_" + cls_device_type + + # Note: device and dtype suffix placement + # Special handling here to place dtype(s) after device according to test name convention. + dtype_kwarg = None + if "dtype" in param_kwargs or "dtypes" in param_kwargs: + dtype_kwarg = ( + param_kwargs["dtypes"] + if "dtypes" in param_kwargs + else param_kwargs["dtype"] + ) + test_name = ( + f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}" + ) + + instantiate_test_helper( + cls=cls, + name=test_name, + test=test, + param_kwargs=param_kwargs, + decorator_fn=decorator_fn, + ) + + def run(self, result=None): + super().run(result=result) + # Early terminate test if _stop_test_suite is set. + if self._stop_test_suite: + result.stop() + + +class CPUTestBase(DeviceTypeTestBase): + device_type = "cpu" + + # No critical error should stop CPU test suite + def _should_stop_test_suite(self): + return False + + +class CUDATestBase(DeviceTypeTestBase): + device_type = "cuda" + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + primary_device: ClassVar[str] + cudnn_version: ClassVar[Any] + no_magma: ClassVar[bool] + no_cudnn: ClassVar[bool] + + def has_cudnn(self): + return not self.no_cudnn + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = torch.cuda.device_count() + + prim_device = cls.get_primary_device() + cuda_str = "cuda:{0}" + non_primary_devices = [ + cuda_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + # has_magma shows up after cuda is initialized + t = torch.ones(1).cuda() + cls.no_magma = not torch.cuda.has_magma + + # Determines if cuDNN is available and its version + cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t) + cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() + + # Acquires the current device as the primary (test) device + cls.primary_device = f"cuda:{torch.cuda.current_device()}" + + +# See Note [Lazy Tensor tests in device agnostic testing] +lazy_ts_backend_init = False + + +class LazyTestBase(DeviceTypeTestBase): + device_type = "lazy" + + def _should_stop_test_suite(self): + return False + + @classmethod + def setUpClass(cls): + import torch._lazy + import torch._lazy.metrics + import torch._lazy.ts_backend + + global lazy_ts_backend_init + if not lazy_ts_backend_init: + # Need to connect the TS backend to lazy key before running tests + torch._lazy.ts_backend.init() + lazy_ts_backend_init = True + + +class MPSTestBase(DeviceTypeTestBase): + device_type = "mps" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + prim_device = cls.get_primary_device() + return [prim_device] + + @classmethod + def setUpClass(cls): + cls.primary_device = "mps:0" + + def _should_stop_test_suite(self): + return False + + +class XPUTestBase(DeviceTypeTestBase): + device_type = "xpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + # currently only one device is supported on MPS backend + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = torch.xpu.device_count() + + prim_device = cls.get_primary_device() + xpu_str = "xpu:{0}" + non_primary_devices = [ + xpu_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + cls.primary_device = f"xpu:{torch.xpu.current_device()}" + + def _should_stop_test_suite(self): + return False + + +class HPUTestBase(DeviceTypeTestBase): + device_type = "hpu" + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def setUpClass(cls): + cls.primary_device = "hpu:0" + + +class PrivateUse1TestBase(DeviceTypeTestBase): + primary_device: ClassVar[str] + device_mod = None + device_type = "privateuse1" + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def get_all_devices(cls): + primary_device_idx = int(cls.get_primary_device().split(":")[1]) + num_devices = cls.device_mod.device_count() + prim_device = cls.get_primary_device() + device_str = f"{cls.device_type}:{{0}}" + non_primary_devices = [ + device_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] + return [prim_device] + non_primary_devices + + @classmethod + def setUpClass(cls): + cls.device_type = torch._C._get_privateuse1_backend_name() + cls.device_mod = getattr(torch, cls.device_type, None) + assert ( + cls.device_mod is not None + ), f"""torch has no module of `{cls.device_type}`, you should register + a module by `torch._register_device_module`.""" + cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}" + + +# Adds available device-type-specific test base classes +def get_device_type_test_bases(): + # set type to List[Any] due to mypy list-of-union issue: + # https://github.com/python/mypy/issues/3351 + test_bases: list[Any] = [] + + if IS_SANDCASTLE or IS_FBCODE: + if IS_REMOTE_GPU: + # Skip if sanitizer is enabled or we're on MTIA machines + if ( + not TEST_WITH_ASAN + and not TEST_WITH_TSAN + and not TEST_WITH_UBSAN + and not TEST_WITH_MTIA + ): + test_bases.append(CUDATestBase) + else: + test_bases.append(CPUTestBase) + else: + test_bases.append(CPUTestBase) + if torch.cuda.is_available(): + test_bases.append(CUDATestBase) + + if is_privateuse1_backend_available(): + test_bases.append(PrivateUse1TestBase) + # Disable MPS testing in generic device testing temporarily while we're + # ramping up support. + # elif torch.backends.mps.is_available(): + # test_bases.append(MPSTestBase) + + return test_bases + + +device_type_test_bases = get_device_type_test_bases() + + +def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None): + # device type cannot appear in both except_for and only_for + intersect = set(except_for if except_for else []) & set( + only_for if only_for else [] + ) + assert not intersect, ( + f"device ({intersect}) appeared in both except_for and only_for" + ) + + # Replace your privateuse1 backend name with 'privateuse1' + if is_privateuse1_backend_available(): + privateuse1_backend_name = torch._C._get_privateuse1_backend_name() + + def func_replace(x: str): + return x.replace(privateuse1_backend_name, "privateuse1") + + except_for = ( + ([func_replace(x) for x in except_for] if except_for is not None else None) + if not isinstance(except_for, str) + else func_replace(except_for) + ) + only_for = ( + ([func_replace(x) for x in only_for] if only_for is not None else None) + if not isinstance(only_for, str) + else func_replace(only_for) + ) + + if except_for: + device_type_test_bases = filter( + lambda x: x.device_type not in except_for, device_type_test_bases + ) + if only_for: + device_type_test_bases = filter( + lambda x: x.device_type in only_for, device_type_test_bases + ) + + return list(device_type_test_bases) + + +# Note [How to extend DeviceTypeTestBase to add new test device] +# The following logic optionally allows downstream projects like pytorch/xla to +# add more test devices. +# Instructions: +# - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project. +# - Inside the file, one should inherit from `DeviceTypeTestBase` class and define +# a new DeviceTypeTest class (e.g. `XLATestBase`) with proper implementation of +# `instantiate_test` method. +# - DO NOT import common_device_type inside the file. +# `runpy.run_path` with `globals()` already properly setup the context so that +# `DeviceTypeTestBase` is already available. +# - Set a top-level variable `TEST_CLASS` equal to your new class. +# E.g. TEST_CLASS = XLATensorBase +# - To run tests with new device type, set `TORCH_TEST_DEVICE` env variable to path +# to this file. Multiple paths can be separated by `:`. +# See pytorch/xla/test/pytorch_test_base.py for a more detailed example. +_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None) +if _TORCH_TEST_DEVICES: + for path in _TORCH_TEST_DEVICES.split(":"): + # runpy (a stdlib module) lacks annotations + mod = runpy.run_path(path, init_globals=globals()) # type: ignore[func-returns-value] + device_type_test_bases.append(mod["TEST_CLASS"]) + + +PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1" + +PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR" +PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR" +PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM" + + +def get_desired_device_type_test_bases( + except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False +): + # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy` + test_bases = device_type_test_bases.copy() + if allow_mps and TEST_MPS and MPSTestBase not in test_bases: + test_bases.append(MPSTestBase) + if allow_xpu and TEST_XPU and XPUTestBase not in test_bases: + test_bases.append(XPUTestBase) + if TEST_HPU and HPUTestBase not in test_bases: + test_bases.append(HPUTestBase) + # Filter out the device types based on user inputs + desired_device_type_test_bases = filter_desired_device_types( + test_bases, except_for, only_for + ) + if include_lazy: + # Note [Lazy Tensor tests in device agnostic testing] + # Right now, test_view_ops.py runs with LazyTensor. + # We don't want to opt every device-agnostic test into using the lazy device, + # because many of them will fail. + # So instead, the only way to opt a specific device-agnostic test file into + # lazy tensor testing is with include_lazy=True + if IS_FBCODE: + print( + "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds", + file=sys.stderr, + ) + else: + desired_device_type_test_bases.append(LazyTestBase) + + def split_if_not_empty(x: str): + return x.split(",") if x else [] + + # run some cuda testcases on other devices if available + # Usage: + # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1 + env_custom_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "") + ) + if env_custom_only_for: + desired_device_type_test_bases += filter( + lambda x: x.device_type in env_custom_only_for, test_bases + ) + desired_device_type_test_bases = list(set(desired_device_type_test_bases)) + + # Filter out the device types based on environment variables if available + # Usage: + # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu + # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla + env_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "") + ) + env_except_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "") + ) + + return filter_desired_device_types( + desired_device_type_test_bases, env_except_for, env_only_for + ) + + +# Adds 'instantiated' device-specific test cases to the given scope. +# The tests in these test cases are derived from the generic tests in +# generic_test_class. This function should be used instead of +# instantiate_parametrized_tests() if the test class contains +# device-specific tests (NB: this supports additional @parametrize usage). +# +# See note "Writing Test Templates" +# TODO: remove "allow_xpu" option after Interl GPU support all test case instantiate by this function. +def instantiate_device_type_tests( + generic_test_class, + scope, + except_for=None, + only_for=None, + include_lazy=False, + allow_mps=False, + allow_xpu=False, +): + # Removes the generic test class from its enclosing scope so its tests + # are not discoverable. + del scope[generic_test_class.__name__] + + generic_members = set(generic_test_class.__dict__.keys()) + generic_tests = [x for x in generic_members if x.startswith("test")] + + # Creates device-specific test cases + for base in get_desired_device_type_test_bases( + except_for, only_for, include_lazy, allow_mps, allow_xpu + ): + class_name = generic_test_class.__name__ + base.device_type.upper() + + # type set to Any and suppressed due to unsupported runtime class: + # https://github.com/python/mypy/wiki/Unsupported-Python-Features + device_type_test_class: Any = type(class_name, (base, generic_test_class), {}) + + # Arrange for setUpClass and tearDownClass methods defined both in the test template + # class and in the generic base to be called. This allows device-parameterized test + # classes to support setup and teardown. + # NB: This should be done before instantiate_test() is called as that invokes setup. + @classmethod + def _setUpClass(cls): + # This should always be called, whether or not the test class invokes + # super().setUpClass(), to set the primary device. + base.setUpClass() + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.setUpClass.__func__(cls) + + @classmethod + def _tearDownClass(cls): + # We want to call the @classmethod defined in the generic base, but pass + # it the device-specific class object (cls), hence the __func__ call. + generic_test_class.tearDownClass.__func__(cls) + base.tearDownClass() + + device_type_test_class.setUpClass = _setUpClass + device_type_test_class.tearDownClass = _tearDownClass + + for name in generic_members: + if name in generic_tests: # Instantiates test member + test = getattr(generic_test_class, name) + # XLA-compat shim (XLA's instantiate_test takes doesn't take generic_cls) + sig = inspect.signature(device_type_test_class.instantiate_test) + if len(sig.parameters) == 3: + # Instantiates the device-specific tests + device_type_test_class.instantiate_test( + name, copy.deepcopy(test), generic_cls=generic_test_class + ) + else: + device_type_test_class.instantiate_test(name, copy.deepcopy(test)) + # Ports non-test member. Setup / teardown have already been handled above + elif name not in device_type_test_class.__dict__: + nontest = getattr(generic_test_class, name) + setattr(device_type_test_class, name, nontest) + + # Mimics defining the instantiated class in the caller's file + # by setting its module to the given class's and adding + # the module to the given scope. + # This lets the instantiated class be discovered by unittest. + device_type_test_class.__module__ = generic_test_class.__module__ + scope[class_name] = device_type_test_class + + # Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're + # not discoverable. This mutates the original class (TestFoo), which was removed from + # scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda) + # have already been created and the generic forms are no longer needed. + for name in generic_tests: + delattr(generic_test_class, name) + + +# Category of dtypes to run an OpInfo-based test for +# Example use: @ops(dtype=OpDTypes.supported) +# +# There are 7 categories: +# - supported: Every dtype supported by the operator. Use for exhaustive +# testing of all dtypes. +# - unsupported: Run tests on dtypes not supported by the operator. e.g. for +# testing the operator raises an error and doesn't crash. +# - supported_backward: Every dtype supported by the operator's backward pass. +# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass. +# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the +# operator supports in both forward and backward. +# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test +# when this is selected. +# - any_common_cpu_cuda_one: Pick a dtype that supports both CPU and CUDA. +class OpDTypes(Enum): + supported = 0 # Test all supported dtypes (default) + unsupported = 1 # Test only unsupported dtypes + supported_backward = 2 # Test all supported backward dtypes + unsupported_backward = 3 # Test only unsupported backward dtypes + any_one = 4 # Test precisely one supported dtype + none = 5 # Instantiate no dtype variants (no dtype kwarg needed) + any_common_cpu_cuda_one = ( + 6 # Test precisely one supported dtype that is common to both cuda and cpu + ) + + +# Arbitrary order +ANY_DTYPE_ORDER = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + torch.float16, + torch.bfloat16, + torch.long, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.float8_e4m3fn, + torch.float8_e5m2, +) + + +def _serialize_sample(sample_input): + # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. + if getattr(sample_input, "summary", None) is not None: + return sample_input.summary() + return str(sample_input) + + +# Decorator that defines the OpInfos a test template should be instantiated for. +# +# Example usage: +# +# @ops(unary_ufuncs) +# def test_numerics(self, device, dtype, op): +# +# +# This will instantiate variants of test_numerics for each given OpInfo, +# on each device the OpInfo's operator supports, and for every dtype supported by +# that operator. There are a few caveats to the dtype rule, explained below. +# +# The @ops decorator can accept two +# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified +# then the test variants are instantiated for those dtypes, regardless of +# what the operator supports. If given "allowed_dtypes" then test variants +# are instantiated only for the intersection of allowed_dtypes and the dtypes +# they would otherwise be instantiated with. That is, allowed_dtypes composes +# with the options listed above and below. +# +# The "dtypes" argument can also accept additional values (see OpDTypes above): +# OpDTypes.supported - the test is instantiated for all dtypes the operator +# supports +# OpDTypes.unsupported - the test is instantiated for all dtypes the operator +# doesn't support +# OpDTypes.supported_backward - the test is instantiated for all dtypes the +# operator's gradient formula supports +# OpDTypes.unsupported_backward - the test is instantiated for all dtypes the +# operator's gradient formula doesn't support +# OpDTypes.any_one - the test is instantiated for one dtype the +# operator supports. The dtype supports forward and backward if possible. +# OpDTypes.none - the test is instantiated without any dtype. The test signature +# should not include a dtype kwarg in this case. +# OpDTypes.any_common_cpu_cuda_one - the test is instantiated for a dtype +# that supports both CPU and CUDA. +# +# These options allow tests to have considerable control over the dtypes +# they're instantiated for. + + +class ops(_TestParametrizer): + def __init__( + self, + op_list, + *, + dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, + allowed_dtypes: Optional[Sequence[torch.dtype]] = None, + skip_if_dynamo=True, + ): + self.op_list = list(op_list) + self.opinfo_dtypes = dtypes + self.allowed_dtypes = ( + set(allowed_dtypes) if allowed_dtypes is not None else None + ) + self.skip_if_dynamo = skip_if_dynamo + + def _parametrize_test(self, test, generic_cls, device_cls): + """Parameterizes the given test function across each op and its associated dtypes.""" + if device_cls is None: + raise RuntimeError( + "The @ops decorator is only intended to be used in a device-specific " + "context; use it with instantiate_device_type_tests() instead of " + "instantiate_parametrized_tests()" + ) + + op = check_exhausted_iterator = object() + for op in self.op_list: + # Determine the set of dtypes to use. + dtypes: Union[set[torch.dtype], set[None]] + if isinstance(self.opinfo_dtypes, Sequence): + dtypes = set(self.opinfo_dtypes) + elif self.opinfo_dtypes == OpDTypes.unsupported_backward: + dtypes = set(get_all_dtypes()).difference( + op.supported_backward_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported_backward: + dtypes = op.supported_backward_dtypes(device_cls.device_type) + elif self.opinfo_dtypes == OpDTypes.unsupported: + dtypes = set(get_all_dtypes()).difference( + op.supported_dtypes(device_cls.device_type) + ) + elif self.opinfo_dtypes == OpDTypes.supported: + dtypes = set(op.supported_dtypes(device_cls.device_type)) + elif self.opinfo_dtypes == OpDTypes.any_one: + # Tries to pick a dtype that supports both forward or backward + supported = op.supported_dtypes(device_cls.device_type) + supported_backward = op.supported_backward_dtypes( + device_cls.device_type + ) + supported_both = supported.intersection(supported_backward) + dtype_set = supported_both if len(supported_both) > 0 else supported + for dtype in ANY_DTYPE_ORDER: + if dtype in dtype_set: + dtypes = {dtype} + break + else: + dtypes = {} + elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one: + # Tries to pick a dtype that supports both CPU and CUDA + supported = set(op.dtypes).intersection(op.dtypesIfCUDA) + if supported: + dtypes = { + next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported) + } + else: + dtypes = {} + + elif self.opinfo_dtypes == OpDTypes.none: + dtypes = {None} + else: + raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}") + + if self.allowed_dtypes is not None: + dtypes = dtypes.intersection(self.allowed_dtypes) + + # Construct the test name; device / dtype parts are handled outside. + # See [Note: device and dtype suffix placement] + test_name = op.formatted_name + + # Filter sample skips / xfails to only those that apply to the OpInfo. + # These are defined on the test function via decorators. + sample_skips_and_xfails = getattr(test, "sample_skips_and_xfails", None) + if sample_skips_and_xfails is not None: + sample_skips_and_xfails = [ + rule + for rule in sample_skips_and_xfails + if rule.op_match_fn(device_cls.device_type, op) + ] + + for dtype in dtypes: + # Construct parameter kwargs to pass to the test. + param_kwargs = {"op": op} + _update_param_kwargs(param_kwargs, "dtype", dtype) + + # NOTE: test_wrapper exists because we don't want to apply + # op-specific decorators to the original test. + # Test-specific decorators are applied to the original test, + # however. + try: + + @wraps(test) + def test_wrapper(*args, **kwargs): + try: + return test(*args, **kwargs) + except unittest.SkipTest as e: + raise e + except Exception as e: + tracked_input = get_tracked_input() + if PRINT_REPRO_ON_FAILURE and tracked_input is not None: + e_tracked = Exception( # noqa: TRY002 + f"{str(e)}\n\nCaused by {tracked_input.type_desc} " + f"at index {tracked_input.index}: " + f"{_serialize_sample(tracked_input.val)}" + ) + e_tracked._tracked_input = tracked_input # type: ignore[attr] + raise e_tracked from e + raise e + finally: + clear_tracked_input() + + if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR: + test_wrapper = skipIfTorchDynamo( + "Policy: we don't run OpInfo tests w/ Dynamo" + )(test_wrapper) + + # Initialize info for the last input seen. This is useful for tracking + # down which inputs caused a test failure. Note that TrackedInputIter is + # responsible for managing this. + test.tracked_input = None + + decorator_fn = partial( + op.get_decorators, + generic_cls.__name__, + test.__name__, + device_cls.device_type, + dtype, + ) + + if sample_skips_and_xfails is not None: + test_wrapper.sample_skips_and_xfails = sample_skips_and_xfails + + yield (test_wrapper, test_name, param_kwargs, decorator_fn) + except Exception as ex: + # Provides an error message for debugging before rethrowing the exception + print(f"Failed to instantiate {test_name} for op {op.name}!") + raise ex + if op is check_exhausted_iterator: + raise ValueError( + "An empty op_list was passed to @ops. " + "Note that this may result from reuse of a generator." + ) + + +# Decorator that skips a test if the given condition is true. +# Notes: +# (1) Skip conditions stack. +# (2) Skip conditions can be bools or strings. If a string the +# test base must have defined the corresponding attribute to be False +# for the test to run. If you want to use a string argument you should +# probably define a new decorator instead (see below). +# (3) Prefer the existing decorators to defining the 'device_type' kwarg. +class skipIf: + def __init__(self, dep, reason, device_type=None): + self.dep = dep + self.reason = reason + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def dep_fn(slf, *args, **kwargs): + if ( + self.device_type is None + or self.device_type == slf.device_type + or ( + isinstance(self.device_type, Iterable) + and slf.device_type in self.device_type + ) + ): + if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( + isinstance(self.dep, bool) and self.dep + ): + raise unittest.SkipTest(self.reason) + + return fn(slf, *args, **kwargs) + + return dep_fn + + +# Skips a test on CPU if the condition is true. +class skipCPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cpu") + + +# Skips a test on CUDA if the condition is true. +class skipCUDAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="cuda") + + +# Skips a test on XPU if the condition is true. +class skipXPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xpu") + + +# Skips a test on XPU or CUDA if the condition is true. +class skipGPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type=GPU_TYPES) + + +# Skips a test on Lazy if the condition is true. +class skipLazyIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="lazy") + + +# Skips a test on Meta if the condition is true. +class skipMetaIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="meta") + + +# Skips a test on MPS if the condition is true. +class skipMPSIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="mps") + + +class skipHPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="hpu") + + +# Skips a test on XLA if the condition is true. +class skipXLAIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type="xla") + + +class skipPRIVATEUSE1If(skipIf): + def __init__(self, dep, reason): + device_type = torch._C._get_privateuse1_backend_name() + super().__init__(dep, reason, device_type=device_type) + + +def _has_sufficient_memory(device, size): + device_ = torch.device(device) + device_type = device_.type + if device_type in ["cuda", "xpu"]: + acc = torch.accelerator.current_accelerator() + # Case 1: no accelerator found + if not acc: + return False + # Case 2: accelerator found but not matching device type + if acc.type != device_type: + return True + # Case 3: accelerator found and matching device type but not available + if not torch.accelerator.is_available(): + return False + # Case 4: accelerator found and matching device type and available + gc.collect() + torch.accelerator.empty_cache() + + if device_.index is None: + device_ = torch.device(device_type, 0) + + if device_type == "cuda": + return ( + torch.cuda.memory.mem_get_info(device_)[0] + * torch.cuda.memory.get_per_process_memory_fraction(device_) + ) >= size + + if device_type == "xpu": + return torch.xpu.memory.mem_get_info(device_)[0] >= size + + if device_type == "xla": + raise unittest.SkipTest("TODO: Memory availability checks for XLA?") + + if device_type != "cpu": + raise unittest.SkipTest("Unknown device type") + + # CPU + if not HAS_PSUTIL: + raise unittest.SkipTest("Need psutil to determine if memory is sufficient") + + # The sanitizers have significant memory overheads + if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN: + effective_size = size * 10 + else: + effective_size = size + + # don't try using all RAM on s390x, leave some for service processes + if IS_S390X: + effective_size = effective_size * 2 + + if psutil.virtual_memory().available < effective_size: + gc.collect() + return psutil.virtual_memory().available >= effective_size + + +def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR): + """Skip test if the device has insufficient memory to run the test + + size may be a number of bytes, a string of the form "N GB", or a callable + + If the test is a device generic test, available memory on the primary device will be checked. + It can also be overridden by the optional `device=` argument. + In other tests, the `device=` argument needs to be specified. + """ + if isinstance(size, str): + assert size.endswith(("GB", "gb")), "only bytes or GB supported" + size = 1024**3 * int(size[:-2]) + + def inner(fn): + @wraps(fn) + def dep_fn(self, *args, **kwargs): + size_bytes: int = size(self, *args, **kwargs) if callable(size) else size + _device = device + if _device is None: + if hasattr(self, "get_primary_device"): + _device = self.get_primary_device() + else: + _device = self.device + + # If this is running with GPU cpp_wrapper, the autotuning step will generate + # an additional array of the same size as the input. + if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu": + size_bytes *= 2 + if not _has_sufficient_memory(_device, size_bytes): + raise unittest.SkipTest(f"Insufficient {_device} memory") + + return fn(self, *args, **kwargs) + + return dep_fn + + return inner + + +class expectedFailure: + def __init__(self, device_type, dtype=None): + self.device_type = device_type + self.dtype = dtype + + def __call__(self, fn): + @wraps(fn) + def efail_fn(slf, *args, **kwargs): + if ( + not hasattr(slf, "device_type") + and hasattr(slf, "device") + and isinstance(slf.device, str) + ): + target_device_type = slf.device + else: + target_device_type = slf.device_type + + target_dtype = kwargs.get("dtype", getattr(slf, "dtype", None)) + device_matches = ( + self.device_type is None or self.device_type == target_device_type + ) + dtype_matches = self.dtype is None or self.dtype == target_dtype + + if device_matches and dtype_matches: + try: + fn(slf, *args, **kwargs) + except Exception: + return + else: + slf.fail("expected test to fail, but it passed") + + return fn(slf, *args, **kwargs) + + return efail_fn + + +class onlyOn: + def __init__(self, device_type: Union[str, list]): + self.device_type = device_type + + def __call__(self, fn): + @wraps(fn) + def only_fn(slf, *args, **kwargs): + if slf.device_type not in self.device_type: + reason = f"Only runs on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(slf, *args, **kwargs) + + return only_fn + + +# Decorator that provides all available devices of the device type to the test +# as a list of strings instead of providing a single device string. +# Skips the test if the number of available devices of the variant's device +# type is less than the 'num_required_devices' arg. +class deviceCountAtLeast: + def __init__(self, num_required_devices): + self.num_required_devices = num_required_devices + + def __call__(self, fn): + assert not hasattr(fn, "num_required_devices"), ( + f"deviceCountAtLeast redefinition for {fn.__name__}" + ) + fn.num_required_devices = self.num_required_devices + + @wraps(fn) + def multi_fn(slf, devices, *args, **kwargs): + if len(devices) < self.num_required_devices: + reason = f"fewer than {self.num_required_devices} devices detected" + raise unittest.SkipTest(reason) + + return fn(slf, devices, *args, **kwargs) + + return multi_fn + + +# Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1) +def onlyNativeDeviceTypes(fn: Callable[_P, _T]) -> Callable[_P, _T]: + @wraps(fn) + def only_fn(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if self.device_type not in NATIVE_DEVICES: + reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +# Only runs the test on the native device types and devices specified in the devices list +def onlyNativeDeviceTypesAnd(devices=None): + def decorator(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if ( + self.device_type not in NATIVE_DEVICES + and self.device_type not in devices + ): + reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + return decorator + + +# Specifies per-dtype precision overrides. +# Ex. +# +# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's precision will be set to the +# corresponding override, if it exists. +# self.precision can be accessed directly, and it also controls the behavior of +# functions like self.assertEqual(). +# +# Note that self.precision is a scalar value, so if you require multiple +# precisions (or are working with multiple dtypes) they should be specified +# explicitly and computed using self.precision (e.g. +# self.precision *2, max(1, self.precision)). +class precisionOverride: + def __init__(self, d): + assert isinstance(d, dict), ( + "precisionOverride not given a dtype : precision dict!" + ) + for dtype in d: + assert isinstance(dtype, torch.dtype), ( + f"precisionOverride given unknown dtype {dtype}" + ) + + self.d = d + + def __call__(self, fn): + fn.precision_overrides = self.d + return fn + + +# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over +# precisionOverride. +# Ex. +# +# @toleranceOverride({torch.float : tol(atol=1e-2, rtol=1e-3}, +# torch.double : tol{atol=1e-4, rtol = 0}) +# @dtypes(torch.half, torch.float, torch.double) +# def test_X(self, device, dtype): +# ... +# +# When the test is instantiated its class's tolerance will be set to the +# corresponding override, if it exists. +# self.rtol and self.precision can be accessed directly, and they also control +# the behavior of functions like self.assertEqual(). +# +# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and +# atol = 1e-4 and rtol = 0 for torch.double. +tol = namedtuple("tol", ["atol", "rtol"]) + + +class toleranceOverride: + def __init__(self, d): + assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" + for dtype, prec in d.items(): + assert isinstance(dtype, torch.dtype), ( + f"toleranceOverride given unknown dtype {dtype}" + ) + assert isinstance(prec, tol), ( + "toleranceOverride not given a dtype : tol dict!" + ) + + self.d = d + + def __call__(self, fn): + fn.tolerance_overrides = self.d + return fn + + +# Decorator that instantiates a variant of the test for each given dtype. +# Notes: +# (1) Tests that accept the dtype argument MUST use this decorator. +# (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU +# or dtypesIfCUDA. +# (3) Can accept an iterable of dtypes or an iterable of tuples +# of dtypes. +# Examples: +# @dtypes(torch.float32, torch.float64) +# @dtypes((torch.long, torch.float32), (torch.int, torch.float64)) +class dtypes: + def __init__(self, *args, device_type="all"): + if len(args) > 0 and isinstance(args[0], (list, tuple)): + for arg in args: + assert isinstance(arg, (list, tuple)), ( + "When one dtype variant is a tuple or list, " + "all dtype variants must be. " + f"Received non-list non-tuple dtype {str(arg)}" + ) + assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( + f"Unknown dtype in {str(arg)}" + ) + else: + assert all(isinstance(arg, torch.dtype) for arg in args), ( + f"Unknown dtype in {str(args)}" + ) + + self.args = args + self.device_type = device_type + + def __call__(self, fn): + d = getattr(fn, "dtypes", {}) + assert self.device_type not in d, f"dtypes redefinition for {self.device_type}" + d[self.device_type] = self.args + fn.dtypes = d + return fn + + +# Overrides specified dtypes on the CPU. +class dtypesIfCPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cpu") + + +# Overrides specified dtypes on CUDA. +class dtypesIfCUDA(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="cuda") + + +# Overrides specified dtypes on Intel GPU. +class dtypesIfXPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="xpu") + + +class dtypesIfMPS(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="mps") + + +class dtypesIfHPU(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type="hpu") + + +class dtypesIfPRIVATEUSE1(dtypes): + def __init__(self, *args): + super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name()) + + +def onlyCPU(fn): + return onlyOn("cpu")(fn) + + +def onlyCUDA(fn): + return onlyOn("cuda")(fn) + + +def onlyMPS(fn): + return onlyOn("mps")(fn) + + +def onlyXPU(fn): + return onlyOn("xpu")(fn) + + +def onlyHPU(fn): + return onlyOn("hpu")(fn) + + +def onlyPRIVATEUSE1(fn): + device_type = torch._C._get_privateuse1_backend_name() + device_mod = getattr(torch, device_type, None) + if device_mod is None: + reason = f"Skip as torch has no module of {device_type}" + return unittest.skip(reason)(fn) + return onlyOn(device_type)(fn) + + +def onlyCUDAAndPRIVATEUSE1(fn): + @wraps(fn) + def only_fn(self, *args, **kwargs): + if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()): + reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return only_fn + + +def disablecuDNN(fn): + @wraps(fn) + def disable_cudnn(self, *args, **kwargs): + if self.device_type == "cuda" and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_cudnn + + +def disableMkldnn(fn): + @wraps(fn) + def disable_mkldnn(self, *args, **kwargs): + if torch.backends.mkldnn.is_available(): + with torch.backends.mkldnn.flags(enabled=False): + return fn(self, *args, **kwargs) + return fn(self, *args, **kwargs) + + return disable_mkldnn + + +def expectedFailureCPU(fn): + return expectedFailure("cpu")(fn) + + +def expectedFailureCUDA(fn): + return expectedFailure("cuda")(fn) + + +def expectedFailureXPU(fn): + return expectedFailure("xpu")(fn) + + +def expectedFailureMeta(fn): + return skipIfTorchDynamo()(expectedFailure("meta")(fn)) + + +def expectedFailureXLA(fn): + return expectedFailure("xla")(fn) + + +def expectedFailureHPU(fn): + return expectedFailure("hpu")(fn) + + +def expectedFailureMPS(fn): + return expectedFailure("mps")(fn) + + +def expectedFailureMPSComplex(fn): + return expectedFailure("mps", torch.complex64)(fn) + + +def expectedFailureMPSPre15(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 15.0: + return expectedFailure("mps")(fn) + return fn + + +def expectedFailureMPSPre14(fn): + import platform + + version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1) + if not version or version < 1.0: # cpu or other unsupported device + return fn + if version < 14.0: + return expectedFailure("mps")(fn) + return fn + + +# Skips a test on CPU if LAPACK is not available. +def skipCPUIfNoLapack(fn): + return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) + + +# Skips a test on CPU if FFT is not available. +def skipCPUIfNoFFT(fn): + return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")( + fn + ) + + +# Skips a test on CPU if MKL is not available. +def skipCPUIfNoMkl(fn): + return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn) + + +# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows). +def skipCPUIfNoMklSparse(fn): + return skipCPUIf( + IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support" + )(fn) + + +# Skips a test on CPU if mkldnn is not available. +def skipCPUIfNoMkldnn(fn): + return skipCPUIf( + not torch.backends.mkldnn.is_available(), + "PyTorch is built without mkldnn support", + )(fn) + + +# Skips a test on CUDA if MAGMA is not available. +def skipCUDAIfNoMagma(fn): + return skipCUDAIf("no_magma", "no MAGMA library detected")( + skipCUDANonDefaultStreamIf(True)(fn) + ) + + +def has_cusolver(): + return not TEST_WITH_ROCM + + +def has_hipsolver(): + rocm_version = _get_torch_rocm_version() + # hipSOLVER is disabled on ROCM < 5.3 + return rocm_version >= (5, 3) + + +# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available +def skipCUDAIfNoCusolver(fn): + return skipCUDAIf( + not has_cusolver() and not has_hipsolver(), "cuSOLVER not available" + )(fn) + + +# Skips a test if both cuSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoCusolver(fn): + if has_cusolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available +def skipCUDAIfNoMagmaAndNoLinalgsolver(fn): + if has_cusolver() or has_hipsolver(): + return fn + else: + # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA + return skipCUDAIfNoMagma(fn) + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipCUDAIfRocm: {msg}" + return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn) + + if func: + return dec_fn(func) + return dec_fn + + +# Skips a test on CUDA when not using ROCm. +def skipCUDAIfNotRocm(fn): + return skipCUDAIf( + not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack" + )(fn) + + +# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested. +def skipCUDAIfRocmVersionLessThan(version=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if not TEST_WITH_ROCM: + reason = "ROCm not available" + raise unittest.SkipTest(reason) + rocm_version_tuple = _get_torch_rocm_version() + if ( + rocm_version_tuple is None + or version is None + or rocm_version_tuple < tuple(version) + ): + reason = ( + f"ROCm {rocm_version_tuple} is available but {version} required" + ) + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfNotMiopenSuggestNHWC(fn): + return skipCUDAIf( + not TEST_WITH_MIOPEN_SUGGEST_NHWC, + "test doesn't currently work without MIOpen NHWC activation", + )(fn) + + +# Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s. +def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version in (versions or []): + reason = f"test skipped for CUDA version {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test for CUDA versions less than specified, given in the form of [major, minor]. +def skipCUDAIfVersionLessThan(versions: Optional[tuple[int, int]] = None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + version = _get_torch_cuda_version() + if version == (0, 0): # cpu or rocm + return fn(self, *args, **kwargs) + if version < versions: + reason = f"test skipped for CUDA versions < {version}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested. +def skipCUDAIfCudnnVersionLessThan(version=0): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda": + if self.no_cudnn: + reason = "cuDNN not available" + raise unittest.SkipTest(reason) + if self.cudnn_version is None or self.cudnn_version < version: + reason = f"cuDNN version {self.cudnn_version} is available but {version} required" + raise unittest.SkipTest(reason) + + return fn(self, *args, **kwargs) + + return wrap_fn + + return dec_fn + + +# Skips a test on CUDA if cuSparse generic API is not available +def skipCUDAIfNoCusparseGeneric(fn): + return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")( + fn + ) + + +def skipCUDAIfNoHipsparseGeneric(fn): + return skipCUDAIf( + not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available" + )(fn) + + +def skipCUDAIfNoSparseGeneric(fn): + return skipCUDAIf( + not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC), + "Sparse Generic API not available", + )(fn) + + +def skipCUDAIfNoCudnn(fn): + return skipCUDAIfCudnnVersionLessThan(0)(fn) + + +def skipCUDAIfMiopen(fn): + return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn) + + +def skipCUDAIfNoMiopen(fn): + return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")( + skipCUDAIfNoCudnn(fn) + ) + + +def skipLazy(fn): + return skipLazyIf(True, "test doesn't work with lazy tensors")(fn) + + +def skipMeta(fn): + return skipMetaIf(True, "test doesn't work with meta tensors")(fn) + + +def skipXLA(fn): + return skipXLAIf(True, "Marked as skipped for XLA")(fn) + + +def skipMPS(fn): + return skipMPSIf(True, "test doesn't work on MPS backend")(fn) + + +def skipHPU(fn): + return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + + +def skipXPU(fn): + return skipXPUIf(True, "test doesn't work on XPU backend")(fn) + + +def skipPRIVATEUSE1(fn): + return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) + + +# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now. +# This should probably enumerate all available device type test base classes. +def get_all_device_types() -> list[str]: + return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] + + +# skip since currently flex attention requires at least `avx2` support on CPU. +IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = ( + not torch.xpu.is_available() + and not torch.cuda.is_available() + and not IS_MACOS + and torch.cpu._is_avx2_supported() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" +) +IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED = ( + torch.xpu.is_available() and torch.utils._triton.has_triton() +) +flex_attention_supported_platform = unittest.skipUnless( + IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED + or IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED + or ( + torch.cuda.is_available() + and torch.utils._triton.has_triton() + and torch.cuda.get_device_capability() >= (8, 0) + ), + "Requires CUDA and Triton, Intel GPU and triton, or CPU with avx2 and later", +) +if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py new file mode 100644 index 0000000000000000000000000000000000000000..dac77fe9aa731ade2f96a87abb16af21699e2a6a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py @@ -0,0 +1,25236 @@ +# mypy: ignore-errors + +from functools import wraps, partial +from itertools import product, chain, islice +import itertools +import functools +import copy +import operator +import random +import unittest +import math +import enum + +import torch +import numpy as np +import numpy.typing as npt +from torch import inf, nan + +from typing import Any, Union +from collections.abc import Sequence +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, + floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, + empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, +) +from torch.testing._internal.common_device_type import \ + (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, + skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride, + skipCPUIfNoMklSparse, + toleranceOverride, tol, skipXPU) +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, +) +from torch.testing._internal.common_quantized import ( + _bfloat16_to_float4_e2m1fn_x2, +) +from torch.testing._internal.common_utils import ( + make_fullrank_matrices_with_distinct_singular_values, + TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY, + torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, + GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, + TEST_WITH_TORCHINDUCTOR, MACOS_VERSION, +) +from torch.testing._utils import wrapper_set_seed + +import torch._refs as refs # noqa: F401 +import torch._refs.nn.functional +import torch._refs.special +import torch._refs.linalg +import torch._prims as prims # noqa: F401 +from torch.utils import _pytree as pytree + + +from torch._vendor.packaging import version + +from torch.testing._internal.opinfo.core import ( # noqa: F401 + L, + M, + S, + XS, + _NOTHING, + _getattr_qual, + DecorateInfo, + SampleInput, + ErrorInput, + AliasInfo, + NumericsFilter, + OpInfo, + _generate_reduction_inputs, + _generate_reduction_kwargs, + sample_inputs_reduction, + ReductionOpInfo, + reference_inputs_elementwise_binary, + make_error_inputs_elementwise_binary, + generate_elementwise_binary_tensors, + generate_elementwise_binary_arbitrarily_strided_tensors, + generate_elementwise_binary_small_value_tensors, + generate_elementwise_binary_large_value_tensors, + generate_elementwise_binary_extremal_value_tensors, + generate_elementwise_binary_broadcasting_tensors, + generate_elementwise_binary_with_scalar_samples, + generate_elementwise_binary_with_scalar_and_type_promotion_samples, + generate_elementwise_binary_noncontiguous_tensors, + sample_inputs_elementwise_binary, + BinaryUfuncInfo, + sample_inputs_elementwise_unary, + generate_elementwise_unary_tensors, + generate_elementwise_unary_small_value_tensors, + generate_elementwise_unary_large_value_tensors, + generate_elementwise_unary_extremal_value_tensors, + reference_inputs_elementwise_unary, + UnaryUfuncInfo, + sample_inputs_spectral_ops, + SpectralFuncType, + SpectralFuncInfo, + ShapeFuncInfo, + sample_inputs_foreach, + ForeachFuncInfo, + gradcheck_wrapper_hermitian_input, + gradcheck_wrapper_ctc_loss, + gradcheck_wrapper_triangular_input, + gradcheck_wrapper_triangular_input_real_positive_diagonal, + gradcheck_wrapper_masked_operation, + gradcheck_wrapper_masked_pointwise_operation, + clone_sample, +) +from torch.testing._internal.opinfo.refs import ( # NOQA: F401 + _find_referenced_opinfo, + _inherit_constructor_args, + PythonRefInfo, + ReductionPythonRefInfo, + ElementwiseUnaryPythonRefInfo, + ElementwiseBinaryPythonRefInfo, +) +from torch.testing._internal.opinfo.utils import ( + np_unary_ufunc_integer_promotion_wrapper, + reference_reduction_numpy, + prod_numpy +) +from torch.testing._internal import opinfo +from torch.testing._internal.opinfo.definitions.linalg import ( + sample_inputs_linalg_cholesky, + sample_inputs_linalg_cholesky_inverse, + sample_inputs_cross, + sample_inputs_linalg_qr_geqrf, + sample_inputs_linalg_invertible, + sample_inputs_lu_solve, + sample_inputs_legacy_solve, + sample_inputs_svd, + sample_inputs_linalg_det_logdet_slogdet, + sample_inputs_linalg_lu, + sample_inputs_diagonal_diag_embed, + error_inputs_diagonal_diag_embed, +) +from torch.testing._internal.opinfo.definitions.special import ( + sample_inputs_i0_i1, + sample_inputs_polygamma, + reference_polygamma, +) +from torch.testing._internal.opinfo.definitions._masked import ( + sample_inputs_softmax_variant, +) +from torch.testing._internal.opinfo.definitions.sparse import ( + error_inputs_sparse_like_fns, + sample_inputs_sparse_like_fns, + error_inputs_sparse_mul, + sample_inputs_sparse_mul, + error_inputs_sparse_reduction_sum, + sample_inputs_sparse_reduction_sum +) + +if TEST_SCIPY: + from scipy import stats + import scipy.spatial + import scipy.special + + +def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +# test if a tensor is close to an integer +def close_to_int(x, eps=0.1): + if x.is_complex(): + y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x)))) + else: + y = torch.abs(torch.frac(x)) + return (y < eps) | (y > (1 - eps)) + + +def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs): + + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + yield SampleInput(make_input(3), 0) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2) + + yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3) + + yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2) + + +def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + args_cases = ( + # Cases with tensor indices. + (torch.tensor([1, 2, 3]),), + (torch.tensor(1),), + (torch.tensor([1, 2, 3]), 1), + (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1), + # Cases with list of indices. + ((2, 4),), + ((2, 4), 1), + ((2, 4), -1), + # Cases with integer section. + (3,), + (3, 1), + (3, -1), + ) + + for args in args_cases: + yield SampleInput(make_input((S, S, S)), args=args) + + +def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(6, S), 2) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + +def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, S, S), [1, 2, 3]) + yield SampleInput(make_arg(S, S, 6), 2) + +def error_inputs_hsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, " + "but got a tensor with 0 dimensions!") + yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.hsplit attempted to split along dimension 1, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput( + SampleInput(make_arg((S, S, S)), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_vsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.vsplit attempted to split along dimension 0, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), + error_regex=err_msg2) + + # Incorrect type for indices_or_section argument + err_msg3 = ("received an invalid combination of arguments.") + yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"), + error_type=TypeError, error_regex=err_msg3) + +def error_inputs_dsplit(op_info, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, " + "but got a tensor with 1 dimensions!") + yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1) + + err_msg2 = (f"torch.dsplit attempted to split along dimension 2, " + f"but the size of the dimension {S} " + f"is not divisible by the split_size 0!") + yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2) + + +def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = ( + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0), + ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0), + ) + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + kwargs = dict(storage_offset=storage_offset) + yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs) + +def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(): + base = make_tensor((20,), device=device, dtype=dtype) + return base[5:15].requires_grad_(requires_grad) + + # as_strided on offset, partial views + yield SampleInput(make_arg(), (2, 2), (1, 2)) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0) + yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10) + +def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input shape, output shape, output stride, output storage offset + test_cases = [ + ((1,), (), (), 0), + ((1,), (1,), (1,), 0), + ((3, 3), (2, 2), (1, 2), 0), + ((3, 3), (2, 2), (1, 2), 1), + ((3, 3), (2, 2), (2, 1), 0), + # Scatter to larger dimensions + ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0), + # Scatter to larger dimensions with strides inverted + ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0), + ] + + for input_shape, output_shape, stride, storage_offset in test_cases: + input_t = make_arg(input_shape) + input_src = make_arg(output_shape) + yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset) + + +def error_inputs_as_strided_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # Create a small tensor and try to scatter it out of bounds + input_t = make_arg([4, 4]) + input_src = make_arg([2, 2]) + yield ErrorInput( + SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0), + error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64" + ) + + +def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs): + inputs = ( + (0,), + (0, 1), + (0, 1, 2, 3), + ) + + rvals = [1, 2, 4] + + products = product(inputs, rvals, [False, True]) + + for input_data, r, with_replacement in products: + input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(input_t, r=r, with_replacement=with_replacement) + +def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # constructs 1-D tensors with varying number of elements + a = make_arg((0,)) + b = make_arg((0, 1)) + c = make_arg((0, 1, 2, 3)) + + # sample with only 1 tensor + yield SampleInput(a) + + # sample with 2 tensors + yield SampleInput(a, b) + + # sample with 3 tensors + yield SampleInput(a, b, c) + +def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input_shape, dict of dim and eps + cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S), {'dim': 1}), + ((S, 2), {'dim': -1}), + ((S,), {'dim': 0, 'eps': 0.5}), + ((), {'dim': 0}), + ((S, S, M), {'dim': 2}), + ((S, S), {}) + ) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs) + # Test for Broadcasting + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2}) + yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + + +def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + + cases = ( + (), + (()), + (1), + ((1,)), + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def error_inputs_item(op, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + + cases = ( + (M), + ((S,)), + (S, S), + (S, M, L), + ) + + for shape in cases: + yield ErrorInput( + SampleInput(make_arg(shape)), error_type=RuntimeError, + error_regex="elements cannot be converted to Scalar") + + +def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for training, momentum, eps + cases: tuple[tuple[int, ...], dict] = ( + ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}), + ((3, 2, 4), {'training': False, 'momentum': -1.2}), + ((3, 1), {'training': True, 'momentum': 0.0}), + ((0,), {'training': True}), + ((0,), {'training': False}), + ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}), + ((2, 1), {}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight = make_arg(channels) if channels > 0 else None + bias = make_arg(channels) if channels > 0 else None + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + weight, + bias + ), + kwargs=kwargs + ) + + # Checking for permutations of weights and biases as `None` + is_training = [True, False, False] + + for training in is_training: + yield SampleInput( + make_arg(input_shape), + args=( + running_mean, + running_var, + make_arg(channels), + make_arg(channels) + ), + kwargs={'training': training} + ) + + # Test case for no optional kwargs + # running_mean and running_var are required in evaluation mode (training: False) but not in training mode + yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) + +def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), 0), + ((S, S), 0), + ((S, M, S), -1), + ] + input_dtypes = [dtype] + if dtype == torch.float and device == 'cuda': + input_dtypes += [torch.float16] + + for (shape, dim), input_dtype in product(cases, input_dtypes): + input = make_arg(shape) + output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype) + yield SampleInput(make_arg(shape), output, dim, input_dtype) + +def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + + +def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + training = sample.kwargs.get('training', True) + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if args[0] is not None and args[1] is not None: + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps)) + else: + yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps)) + +def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs): + samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) + for sample in samples: + # torch.native_batch_norm does not support 0 numel tensors + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) + if sample.input.numel() == 0: + continue + args = sample.args + momentum = sample.kwargs.get('momentum', 0.5) + eps = sample.kwargs.get('eps', 1e-5) + if any(args[i] is None for i in range(4)): + continue + yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps)) + +def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs): + op_kwargs = op_info.sample_kwargs(device, dtype, None)[0] + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad, + op_kwargs=op_kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + (()), + ((S, )), + ((S, S)), + ((S, M, S)) + ) + + for shape in cases: + for weight in [-1., 0., 0.8, 1.]: + weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(shape), args=(weight_tensor,)) + + channel_size = shape[1] if len(shape) >= 2 else 1 + yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),)) + + weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,)) + yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),)) + +def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs) + yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs) + +def sample_kwargs_prelu_scalar_weight(device, dtype, input): + weight = torch.rand((), device=device, dtype=dtype) + # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case + if dtype == torch.bfloat16: + weight_cpu = weight.to(dtype=torch.float32, device="cpu") + else: + weight_cpu = weight.cpu() + np_weight = weight_cpu.numpy() + return ({'weight': weight}, {'weight': np_weight}) + +def error_inputs_prelu(op, device): + # Weight has numel != 1, but self.ndim is zero-dim tensor + inp = make_tensor((), device=device, dtype=torch.float32) + weight = make_tensor((2,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Not allow zero-dim input tensor.") + + # Weight has numel != 1, but numel does not match channel size + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((9,), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="Mismatch of parameter numbers and input channel size.") + + # Weight is neither a scalar nor 1-D tensor + inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32) + weight = make_tensor((2, 4), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}), + error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2") + + # src and index tensors must have the same # of dimensions +def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # ord = inf is tested in inputs_norm_inf as it fails on some tests + cases = [ + ((S, S), (2,), '2'), + ((S, S), (0,), '0'), + ((S, S), (0.5,), '0_5'), + ((S, S), (1,), '1'), + ((S, S), (3,), '3'), + ((S, S), (-1,), 'neg_1'), + ((S, S), (-2,), 'neg_2'), + ((S, S), (-0.5,), 'neg_0_5'), + ((S, S), (-1.5,), 'neg_1_5'), + ] + + cases_nonzero_input = ( + ((S, S, S), (1.5,), '1_5_default'), + ((S, S, S), (1.5, 1), '1_5_dim'), + ((S, S, S), (1.5, -1), '1_5_neg_dim'), + ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'), + ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'), + ) + + cases_posdim = ( + ((S, S), (-2, 1,), 'neg_2_dim'), + ((S, S), (-1, 1,), 'neg_1_dim'), + ((S, S), (0, 1,), '0_dim'), + ((S, S), (1, 1,), '1_dim'), + ((S, S), (2, 1,), '2_dim'), + ((S, S), (3, 1,), '3_dim'), + ((S, S, S), (2, 1), '2_dim'), + ((S, S, S), (3, 1), '3_dim'), + ((S, S, S), (2, 1, True), 'keepdim_2_dim'), + ((S, S, S), (3, 1, True), 'keepdim_3_dim'), + ((), (2, 0), '2_dim_scalar'), + ((), (3, 0), '3_dim_scalar'), + ((), (2, 0, True), 'keepdim_2_dim_scalar'), + ((), (3, 0, True), 'keepdim_3_dim_scalar'), + ) + + cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim")) + for shape, args, name in cases_posdim) + + for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim): + yield SampleInput(make_arg(shape), args=args, name=name) + + for shape, args, name in cases_nonzero_input: + yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name) + + +def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (), 'default'), + ((S, S), ('fro',), 'fro_default'), + ((S, S), ('fro', [0, 1],), 'fro'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), ('nuc',), 'nuc'), + ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((S, S), (-inf,), '-inf'), + ((S, S), (inf,), 'inf'), + ((S, S), (inf, 1,), 'inf_2_dim'), + ((S, S), (inf, -1,), 'inf_2_neg_dim'), + ) + + for shape, args, name in cases: + yield SampleInput(make_arg(shape), args=args, name=name) + + +def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((), (S,)), + ((S, 1), (S,)), + ((M, S), ()), + ((S, S), (S, S)) + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + rhs = make_arg(shape_rhs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input) + if shape_lhs == shape_rhs: + yield SampleInput(lhs, args=(lhs.clone().detach_(),)) + + +def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + num_inputs = kwargs.get('num_inputs') + sample_kwargs = kwargs.get('sample_kwargs', {}) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs) + args = [make_arg(shape_rhs) for _ in range(num_inputs - 1)] + broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)) + + yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input) + +def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs): + shapes = ( + ((), ()), + ((S,), ()), + ((S, 1), (S,)), + ((S, 1), S), + ((M, S), ()), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), + ((0, 1, 3), (0, 10, 3)) + ) + + for shape in shapes: + inp, *arg0 = shape + yield SampleInput(inp, args=tuple(arg0)) + +def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds alpha kwarg cases + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True}) + neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3 + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + if dtype is not torch.bool: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha}) + else: + yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False}) + +def error_inputs_arange(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero') + yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError, + error_regex='upper bound and lower bound inconsistent with step sign') + yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range') + yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range') + +def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs): + int_samples = ( + # positive direction + (-1, 2, 2), + # negative direction + (2, -3, -1), + # start == end + (1, 1, 1), + (1, 1, -1), + # divides evenly + (0, -8, -4), + (1, 5, 2), + # bool + (False, True, True), + # default step + (0, 1, None), + # default start + (None, 3, None), + ) + + def to_float(start, end, step): + start = start + 0.1 if start is not None else None + end = end + 0.1 + step = float(step) if step is not None else None + return start, end, step + + float_samples = ( + # includes endpoint + (0., -8. - 1e-6, -4.), + (1., 5. + 1e-6, 2.), + (0., -8., -4.), + (1., 5., 2.), + *(to_float(start, end, step) for (start, end, step) in int_samples), + ) + + large_samples = ( + (0, 10000, None), + ) + + samples = int_samples + float_samples + if dtype not in (torch.int8, torch.uint8): + samples += large_samples + + for start, end, step in samples: + if start is None: + assert step is None + # Pass end as positional arg + yield SampleInput(end, kwargs={"dtype": dtype, "device": device}) + # (Similar to) calling torch.arange(end=3) + yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device}) + elif step is None: + yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(2) + yield SampleInput(1, args=(3, 1)) + +def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): + shapes = ( + (M,), + (S, S) + ) + + for shape in shapes: + yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((S, S), 0, 5), + ((S, S, S), -2, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + +def error_inputs_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = -1 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}", + ) + +def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.5), + ((S, S), 0, 1), + ((S, S, S), -2, 1), + ) + for shape, median, gamma in samples: + yield SampleInput(make_arg(shape), args=(median, gamma)) + + +def error_inputs_cauchy(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_scale = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_scale,)), + error_type=RuntimeError, + error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}", + ) + + +def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.5), + ((S, S), 1), + ((S, S, S), 1.5), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_exponential(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_rate = 0 + yield ErrorInput( + SampleInput(t, args=(invalid_rate,)), + error_type=RuntimeError, + error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}", + ) + + +def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0.2), + ((S, S), 0.5), + ((S, S, S), 0.8), + ) + for shape, rate in samples: + yield SampleInput(make_arg(shape), args=(rate,)) + + +def error_inputs_geometric(op, device, **kwargs): + t = torch.zeros([10], device=device) + neg_prob = -1 + yield ErrorInput( + SampleInput(t, args=(neg_prob,)), + error_type=RuntimeError, + error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}", + ) + + +def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), 0, 0.25), + ((S, S), 0.5, 1), + ((S, S, S), 0, 0.5), + ) + for shape, mean, std in samples: + yield SampleInput(make_arg(shape), args=(mean, std)) + + +def error_inputs_log_normal(op, device, **kwargs): + t = torch.zeros([10], device=device) + invalid_std = 0 + yield ErrorInput( + SampleInput(t, args=(0, invalid_std)), + error_type=RuntimeError, + error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}", + ) + + +def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs): + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) + samples = ( + ((M,), -100, 100), + ((S, S), 0, 1), + ((S, S, S), 1, 2), + ) + for shape, hi, lo in samples: + yield SampleInput(make_arg(shape), args=(hi, lo)) + +def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs): + # this is a bit messy, as we want the args to be tuples + # so if we pass size as a tuple, we have a tuple containing a tuple + sizes = ( + (M,), + (S, S), + ) + for size in sizes: + yield SampleInput(size, kwargs={'dtype': dtype, 'device': device}) + +def sample_inputs_full(op, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + sizes = ( + (M,), + (S, S), + ) + fill_values = [get_val(dtype), get_val(torch.int)] + + for size, fill_value in product(sizes, fill_values): + yield SampleInput(size, fill_value, dtype=dtype, device=device) + + +def error_inputs_uniform(op, device, **kwargs): + t = torch.zeros([10], device=device) + yield ErrorInput( + SampleInput(t, args=(3, -1)), + error_type=RuntimeError, + error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1", + ) + + +def error_inputs_linspace(op, device, **kwargs): + yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative') + yield ErrorInput( + SampleInput(0, args=(3, 1.)), + error_type=TypeError, + error_regex="received an invalid combination of arguments - got \\(int, int, float", + ) + yield ErrorInput( + SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)), + error_type=RuntimeError, + error_regex="only supports 0-dimensional start and end tensors" + ) + + +def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)] + for start, end, nstep in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1, 4, 50) + starts = (-2., 0, 4.3, 50) + nsteps = (0, 1, 50) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device, requires_grad=False) + + # Extra case to replicate off-by-one issue on CUDA + cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))] + for start, end, nstep, (is_start_tensor, is_end_tensor) in cases: + if dtype == torch.uint8 and (end < 0 or start < 0): + continue + + tensor_options = {"dtype": dtype, "device": device} + if is_start_tensor: + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if is_end_tensor: + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1)) + + +def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + for start, end, nstep, base in product(starts, ends, nsteps, bases): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device}) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device}) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs): + ends = (-3, 0, 1.2, 2, 4) + starts = (-2., 0, 1, 2, 4.3) + nsteps = (0, 1, 2, 4) + bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.) + is_start_end_tensors = ((True, True), (True, False), (False, True)) + make_arg = partial(torch.tensor, device=device) + for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors): + if dtype == torch.uint8 and end < 0 or start < 0: + continue + if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point): + # https://github.com/pytorch/pytorch/issues/82242 + continue + + tensor_options = {"dtype": dtype, "device": device} + + if (is_start_tensor): + start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64) + if (is_end_tensor): + end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64) + + if base is None: + yield SampleInput(start, args=(end, nstep), kwargs=tensor_options) + else: + yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options) + + yield SampleInput(1, args=(3, 1, 2.)) + + +def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Creates additional inputs to test the rtol, atol, and equal_nan params + rtols = [0., 1e-7] + atols = [0., 1e-7] + equal_nans = [False, True] + + products = product(rtols, atols, equal_nans) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for rtol, atol, equal_nan in products: + lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs) + rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs) + + yield SampleInput(lhs, args=(rhs,), + kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan)) + + +def error_inputs_isclose(op, device, **kwargs): + make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}), + error_type=RuntimeError, + error_regex='rtol must be greater than or equal to zero') + + yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}), + error_type=RuntimeError, + error_regex='atol must be greater than or equal to zero') + + +def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((1, 2))) + yield SampleInput(make_arg((2,))) + yield SampleInput(make_arg(())) + + +def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + first_shape, second_shape = (S, M), (M, S) + + yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),)) + + if dtype.is_complex: + yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),)) + + # Matmul of empty matrices + yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),)) + yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),)) + + +def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) + tests_list = [ + ((2, 3), (2, 2), (2, 3), False), + ((3, 3), (3, 3), (3, 3), False), + ] + tests_with_lhs_broadcasting = [ + ((1,), (2, 2), (2, 3), True), + ((), (2, 2), (2, 3), True), + ] + test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] + + kwargs = dict(alpha=alpha_val, beta=beta_val) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape_a, shape_b, shape_c, broadcasts_input in test_cases: + yield SampleInput( + make_arg(shape_a), + make_arg(shape_b), + make_arg(shape_c), + **kwargs, + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shape = (3, 3) + yield SampleInput( + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + make_arg(shape), + **kwargs, + ) + yield SampleInput( + make_arg(shape), + make_arg(shape), + make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad), + **kwargs, + ) + # addmm of empty matrices + if dtype.is_floating_point: + yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs) + # empty matmul with broadcastable input + yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True) + +def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs): + alpha = 2 + 3j if dtype.is_complex else 0.6 + beta = 1 + 2j if dtype.is_complex else 0.2 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C + for m, n, k in itertools.product([0, 5], repeat=3): + yield SampleInput( + torch.eye(m, n, device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + make_arg((k, n)), + alpha=alpha, + beta=beta, + ) + +def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + reductions = ["sum", "mean", "amax", "amin"] + for m, k, reduce in product([5, 7], [3, 11], reductions): + yield SampleInput( + torch.eye(m, m) + .to(device=device, dtype=dtype) + .to_sparse_csr() + .requires_grad_(requires_grad), + make_arg((m, k)), + reduce, + ) + + +def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(S, M), make_arg(M)) + +def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg(M, S, M), make_arg(M, M, S)) + +def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_arg_conj(size): + return make_arg(size).conj().requires_grad_(requires_grad) + + yield SampleInput(make_arg((S, )), make_arg((S, ))) + if dtype.is_complex: + # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor) + # is tested in test_conj_view (which tests operations with only conjugated input tensor + # -- not conjugated arg tensors) + yield SampleInput(make_arg((S, )), make_arg_conj((S, ))) + + +def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)), + error_regex='dot : expected both vectors to have same dtype') + yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)), + error_regex='1D tensors expected') + yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)), + error_regex='inconsistent tensor size') + if device != "cpu" and not is_ref: + yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)), + error_regex='Expected all tensors to be on the same device') + + +def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = (((S,), (S, M), (M,), 1, 1, False), + ((S,), (S, M), (M,), 0.2, 0.6, False), + ) + + test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True), + ((1,), (S, M), (M,), 0.2, 0.6, True), + ((), (S, M), (M,), 1, 1, True), + ((), (S, M), (M,), 0.2, 0.6, True), + ) + + cases = test_cases + test_cases_with_broadcast + + # addmv performs: beta * M + alpha * (mat @ vec) + for size, mat, vec, beta, alpha, broadcasts_input in cases: + yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input) + +def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting + test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + + for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases: + if dtype.is_complex: + beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting) + yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)), + kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting) + +def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + test_cases = [(((S, S), (S, S), (S, S)), False), + (((S, S), (S, 1), (1, S)), False), + (((1,), (S, S, 1), (1, S)), True), + (((), (), ()), False), + (((S, S), (), ()), True), + (((), (S, S, 1), (1, S)), True) + ] + + for input_args, broadcasts_input in test_cases: + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input) + + # addcdiv should accept inputs with zero value + # Currently, it throws ZeroDivisionError when the denominator is zero + # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed + args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg + for arg in input_args) + yield SampleInput( + *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3 + ).with_metadata(broadcasts_input=broadcasts_input) + +def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_addcmul_addcdiv( + op_info, device, dtype, requires_grad, **kwargs) + + # type promotion cases + supported_dtypes = op_info.supported_dtypes(device) + make_arg = partial(make_tensor, device=device, requires_grad=requires_grad) + + types = ( + (torch.float64, torch.complex128), + (torch.bfloat16, torch.float32), + ) + + values = ( + None, + True, False, + 3.14, 3, + 1.0, 1, + 0.0, 0, + -3.14, -3, + 3.14 + 2.71j, + ) + + for (type2, type3), value in product(types, values): + if (type2 not in supported_dtypes or + type3 not in supported_dtypes): + continue + + # RuntimeError: value cannot be converted without overflow + if (type(value) is complex and + type2 is not torch.complex128): + continue + + arg1 = make_arg([5, 5], dtype=dtype) + arg2 = make_arg([5, 5], dtype=type2) + arg3 = make_arg([1, 5], dtype=type3) + + # TypeError: addcdiv(): argument 'value' must be Number, not NoneType + if value is not None: + yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value)) + else: + yield SampleInput(arg1, args=(arg2, arg3)) + +def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs): + test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False), + ((1,), (S, S, S), (S, S, M), 1, 1, True), + ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False), + ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True), + ((), (S, S, S), (S, S, M), 1, 1, True), + ((), (S, S, S), (S, S, M), 0.6, 0.2, True), + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta, + alpha=alpha + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + yield SampleInput( + make_arg(input_shape), + make_arg(batch1_shape), + make_arg(batch2_shape), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ).with_metadata(broadcasts_input=broadcasts_input) + + if dtype.is_complex: + shapes = [(S, S, S), (S, M, S), (S, S, M)] + args = tuple(make_arg(s) for s in shapes) + yield SampleInput( + args[0].transpose_(-1, 1), + args[1].transpose(-1, 1).conj().requires_grad_(requires_grad), + args[2].transpose(-1, 1).conj().requires_grad_(requires_grad), + beta=beta * (1 + 2j), + alpha=alpha * (2 + 3j), + ) + +# TODO: add reduction kwargs +def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (S,), + (S, S), + ) + + for shape in shapes: + # Produce one with weight and one without. + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={}) + yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), + kwargs={'weight': _make_tensor(shape, requires_grad=False)}) + +def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None + ) + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M)) + + yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True) + + if dtype.is_complex: + alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j + elif dtype.is_floating_point: + alpha, beta = 0.2, 0.6 + else: + alpha, beta = 2, 3 + + yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha) + + yield SampleInput( + make_arg(), + make_arg(S), + make_arg(M), + beta=beta, + alpha=alpha, + ).with_metadata(broadcasts_input=True) + + # These samples fail gradcheck + if dtype.is_floating_point and not requires_grad: + tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput( + torch.tensor([[math.nan]], **tensor_options), + torch.tensor([0.0], **tensor_options), + torch.tensor([0.0], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + + yield SampleInput( + torch.tensor([[0.0]], **tensor_options), + torch.tensor([math.nan], **tensor_options), + torch.tensor([math.nan], **tensor_options), + beta=0.0, + alpha=0.0, + ).with_metadata(broadcasts_input=True) + +def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ((), (S, S, S), (S,)) + + for shape in cases: + yield SampleInput(make_arg(shape)) + +def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1), {}), + ((S,), make_target([], low=0, high=S), {"p": 1}), + ((S,), make_target([1], low=0, high=S), {"p": 2}), + ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}), + ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}), + ((M, S), make_target([M], low=0, high=S), {"weight": None}), + ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}), + ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}), + ) + + for input_shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_weight = partial(_make_tensor, requires_grad=False) + + inputs = ( + ((), make_target([], low=0, high=1)), + ((S,), make_target([], low=0, high=S)), + ((S,), make_target([1], low=0, high=S)), + ((M, S), make_target([M], low=0, high=S)), + ) + ps = (1, 2) + margins = (0, 7, -3.14) + weights = (False, True) + reductions = (None, "none", "mean", "sum") + + for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions): + input = _make_tensor(input_shape) + weight_shape = [input.size(-1)] if input.ndim > 0 else [1] + weight = make_weight(weight_shape, low=-10., high=10.) if weight else None + kwargs = {"p": p, "margin": margin, "weight": weight} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(input, args=(target,), kwargs=kwargs) + + +def error_inputs_multi_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]') + # invalid target dtype + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}), + error_type=RuntimeError, error_regex='expected scalar type Long but found Float') + # invalid weight + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}), + error_type=ValueError, error_regex='weight must be one-dimensional') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}), + error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]') + # invalid p + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}), + error_type=ValueError, error_regex='only p == 1 and p == 2 supported') + + +def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((), (0,), True), + ((S, S), (1,), True), + ((S, S), (1,), False), + ((S, S), (-2,), False), + ((S, S), (0, 1), False), + ) + # Test large inputs to check numerical stability + lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,) + for low in lows: + high = low * 2 if low is not None else None + for shape, dim, keepdim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=low, high=high, + requires_grad=requires_grad) + yield SampleInput(t, dim, keepdim) + +def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs) + + # https://github.com/pytorch/pytorch/issues/91843 + t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t, 0, False) + + # tests masking + # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073 + t = torch.tensor(float("inf")) + yield SampleInput(t, 0, True) + +def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + inputs = [ + ((), {}), + ((S, S), {}), + ((0, S, 0), {}), + ((S,), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), {'device': 'cpu'}), + ((S,), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), {'device': 'cuda'})) + + for shape, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, **kwargs) + +def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs) + + # shape + cases = ( + (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1) + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in cases: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + +def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + + inputs = ( + ([], make_target([], low=0, high=1), {}), + ([S], make_target([S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}), + ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}), + ) + + for shape, target, kwargs in inputs: + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) + make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False) + + inputs = ( + # random tests including -1 target labels + ([], make_target([], low=-1, high=1)), + ([S], make_target([S], low=-1, high=S)), + ([M, S], make_target([M, S], low=-1, high=S)), + # repeated target labels and -1 (labels after the first -1 are ignored) + ([], make_target_tensor(-1)), + ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])), + ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])), + ) + reductions = (None, "none", "mean", "sum") + + for (shape, target), reduction in product(inputs, reductions): + kwargs = {} + if reduction is not None: + kwargs["reduction"] = reduction + yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) + + +def error_inputs_multilabel_margin_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='abc is not a valid value for reduction') + # invalid input + yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') + yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') + # invalid target + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]') + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}), + error_type=RuntimeError, + error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]') + + +def get_independent_tensor(tensor): + return tensor.clone().requires_grad_(tensor.requires_grad) + +def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + sample.kwargs.setdefault('device', device) + # With high + yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) + # With low and high + yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs) + +def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): + low = 2 + high = 10 + + for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): + # With high + yield SampleInput( + sample.input, + high, + *sample.args, + **sample.kwargs) + # With low and high + yield SampleInput( + get_independent_tensor(sample.input), + low, + high, + *sample.args, + **sample.kwargs) + +def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes = ( + (), + (S,), + (S, S), + (S, S, S), + ) + + margins = (0., 1.) + reductions = ('sum', 'mean', 'none') + + for shape in shapes: + for margin, reduction in product(margins, reductions): + kwargs = {'margin': margin, 'reduction': reduction} + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=False), + _make_tensor(shape, requires_grad=False)), + kwargs=kwargs) + +def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp1 = make_input((10, )) + inp1[2] = float('nan') + inp2 = make_input((10, )) + inp2[4] = float('nan') + target = make_input((10, )) + inp2[9] = float('nan') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Inf handling + inp1 = make_input((10, )) + inp2[1] = float('inf') + inp2 = make_input((10, )) + inp2[4] = float('inf') + target = make_input((10, )) + inp2[7] = float('inf') + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + + # Broadcasting + inp1 = make_input((5, 2)) + inp2 = make_input((5, 1)) + target = make_input((1, 2)) + yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction}) + +def error_inputs_margin_ranking_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value. + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + # invalid input shapes + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)), + error_regex='margin_ranking_loss : All input tensors should') + +def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs): + # input_shape, output_shape, strides, kwargs + # lengths of output_shape and strides must be equal + inputs = [ + ((), (), (), {}), + ((S, S), (2, 0), (3, 4), {}), + ((0, S, 0), (3, 2, 2), (1, 2, 3), {}), + ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), + ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'})) + + for input_shape, output_shape, strides, kwargs in inputs: + t = make_tensor(input_shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + if is_strided: + yield SampleInput(t, output_shape, strides, **kwargs) + else: + yield SampleInput(t, output_shape, **kwargs) + +def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs): + + inputs = [ + ((), (), {'dtype': dtype, 'device': device}), + ((S,), (4,), {'dtype': dtype, 'device': device}), + ((S, S), (2, 1), {'dtype': dtype, 'device': device}), + ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}), + ] + + for shape, strides, kwargs in inputs: + yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs) + +def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs): + # shape + cases = ( + (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), + ) + + for case in cases: + for layout in itertools.permutations(range(len(case))): + yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad) + +def error_inputs_empty_permuted(op_info, device, **kwargs): + yield ErrorInput( + SampleInput((2,), args=((0, 1),)), + error_type=RuntimeError, + error_regex="Number of dimensions in size does not match the length of the physical_layout" + ) + yield ErrorInput( + SampleInput((2,), args=((3,),)), + error_type=RuntimeError, + error_regex="Dimension out of range" + ) + yield ErrorInput( + SampleInput((2, 3), args=((0, 0),)), + error_type=RuntimeError, + error_regex="Duplicate dim not allowed" + ) + +def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): + # Not including a scalar tensor in vals because meta tests start failing due to + # lack of meta support for _local_scalar_dense + # torch.tensor(2, device=device) + vals = (-5, 0, 1) + + for item in vals: + yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): + # only ints >= 0 are allowed for both arguments, unless m is omitted + sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) + + for n, m in product(sizes, sizes): + if n is None: + continue + + # TODO: no layout + _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad} + if m is None: + yield SampleInput(n, args=(), kwargs=_kwargs) + else: + yield SampleInput(n, args=(m,), kwargs=_kwargs) + +def error_inputs_eye(op_info, device, **kwargs): + # TODO: no layout + _kwargs = {'device': device, 'dtype': torch.float32} + + yield ErrorInput( + SampleInput(-1, args=(), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -1" + ) + + yield ErrorInput( + SampleInput(-7, args=(42,), kwargs=_kwargs), + error_regex="n must be greater or equal to 0, got -7" + ) + + yield ErrorInput( + SampleInput(0, args=(-3,), kwargs=_kwargs), + error_regex="m must be greater or equal to 0, got -3" + ) + + +def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs): + # The scalar we are passing to new_full must be the same dtype + # as the one of the resulting tensor + use_dtype = sample.kwargs.get('dtype', dtype) + yield SampleInput( + sample.input, *sample.args, get_val(use_dtype), **sample.kwargs) + +def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + double_dtype = torch.double if device != "mps:0" else torch.float + inputs = [ + ((), get_val(dtype), {}), + ((S, S), get_val(dtype), {}), + ((0, S, 0), get_val(dtype), {}), + ((S,), get_val(dtype), {'dtype': dtype, 'device': device}), + # Hard-code some dtypes/devices. We want to test cases where the + # (dtype, device) is different from the input's (dtype, device) + ((S,), get_val(double_dtype), {'dtype': double_dtype}), + ((S,), get_val(dtype), {'device': 'cpu'}), + ((S,), get_val(double_dtype), {'dtype': double_dtype, 'device': 'cpu'}), + ] + if torch.cuda.is_available(): + inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + + if torch.mps.is_available() and dtype not in [torch.float64, torch.complex128, torch.uint32, torch.uint16]: + inputs.append(((S,), get_val(dtype), {'device': 'mps'})) + + if not dtype.is_signed: + # For unsigned dtypes, negative values are converted. + inputs.append(((S,), -get_val(dtype), {})) + + for shape, fill_value, kwargs in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(t, fill_value, **kwargs) + +def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs): + cases = [ + ([3], 3, {}), + ([10], 3, {}), + ([3, 10], 3, {}), + ([3], 3, dict(replacement=False)), + ([3], 3, dict(replacement=True)), + ([3, 4], 4, dict(replacement=True)), + ([3, 4], 4, dict(replacement=False)), + ] + + for shape, num_samples, kwargs in cases: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + yield SampleInput(t, num_samples, **kwargs) + +def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs): + def get_value_or_make_tensor(value_or_shape): + if isinstance(value_or_shape, list): + return make_tensor(value_or_shape, dtype=dtype, device=device, + low=0, high=None, + requires_grad=requires_grad) + return value_or_shape + + for value_or_mean_shape, value_or_std_shape, kwargs in cases: + mean = get_value_or_make_tensor(value_or_mean_shape) + std = get_value_or_make_tensor(value_or_std_shape) + yield SampleInput(mean, std, **kwargs) + +def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs): + # value_or_size, value_or_size, kwargs + cases = [ + ([], [], {}), + ([3], [3], {}), + ([3, 4, 2], [3, 4, 2], {}), + ([2, 3], 1.1, {}), + ([1, 2, 3], [5, 2, 3], {}), # broadcasting + ] + + return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs) + +def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs): + yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device) + yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device) + yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad)) + +def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs): + shapes = [ + [3], + [], + [0, 3], + [2, 3, 4], + ] + + for shape in shapes: + t = make_tensor(shape, dtype=dtype, device=device, + low=0, high=1, + requires_grad=requires_grad) + yield SampleInput(t) + +def error_inputs_bernoulli(op_info, device, **kwargs): + # more than one element of the written-to tensor refers to a single memory location + x = torch.rand((1,), device=device).expand((6,)) + err_msg = 'unsupported operation' + yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}), + error_regex=err_msg) + +def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs): + inputs = ( + ((S, S, S), 0), + ((S, S, S), 1), + ((), 0), + ) + + for large_number in (True, False): + for shape, dim in inputs: + t = make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad) + + if large_number and t.dim() > 0: + t[0] = 10000 + yield SampleInput(t, dim) + +def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): + yield SampleInput( + make_tensor((S, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + + +def error_inputs_trace(op, device): + yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix") + + +def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((S, S, S), (2, 1, 0.5)), + ((S, S, S), (2, -1, 0.5)), + ((S, S, S), (1, 2, 3)), + ((S, S, S), (float('inf'), 2, 0.5)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((1, 2, 3), (-1, -2)), + ((1, 2, 3), (-1, 2)), + ((1, 2, 3), (1, -2)), + ((1, 2, 3), (1, 2)), + ((), (0, 0)), + ((1, ), (0, 0)), + ((M, M), (0, 1)), + ((S, S, S), (2, 0)), ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def _numpy_ref_transpose(a, dim0, dim1): + if a.ndim <= 1: + return a + + return np.swapaxes(a, dim0, dim1) + +def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def sample_inputs_T(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((M, M), (M, L)) + return (SampleInput(make_arg(shape)) for shape in shapes) + +def error_inputs_T(self, device, has_ndims_error=False): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # Deprecated behavior in regular PyTorch, but throws an error in primTorch: + # https://github.com/pytorch/pytorch/issues/86968 + if has_ndims_error: + # ndims == 1 + yield ErrorInput(SampleInput(make_arg(M)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + # ndims > 2 + yield ErrorInput(SampleInput(make_arg(M, S, L)), + error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 ' + r'to reverse their shape is not supported\.')) + + +def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False): + """ + This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n). + Their matrix product could be used to generate tensor of shape (*, m, n) of rank k. + """ + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batches = [(), (2,)] + size = [3, 4] + for batch, m, n in product(batches, size, size): + k = 2 + a = make_arg((*batch, m, k)) + b = make_arg((*batch, n, k)) + yield a, b + + +def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # Function that's well defined on the outputs for complex inputs + def fn(usv): + U, S, V = usv + return U @ V.mH, S + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + + # NOTE: since svd_lowrank relies on non rank-revealing SVD, + # it inherits the problem of unstable behavior with repeated + # singular values including zeros. + # Since we want to avoid (repeated) zeros as singular values, + # we can only use k for q. + # This issues could be resolved with using a rank-revealing SVD + # which does not include "zero" singular values. + yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn) + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn) + +def chunk_iter(iterable, size): + it = iter(iterable) + while True: + chunk = tuple(islice(it, size)) + if not chunk: + break + yield chunk + +def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): + # we reuse samples from svd_lowrank which come in group of two with + # kwarg['M'] = None and with kwarg['M'] = + samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs) + for s1, s2 in chunk_iter(samples, 2): + del s1.kwargs['M'] + del s2.kwargs['M'] + s1.kwargs['center'] = False + s2.kwargs['center'] = True + yield s1 + yield s2 + +def np_sinc_with_fp16_as_fp32(x): + # Wraps numpy's sinc function so that fp16 values are promoted to fp32 + # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated + # at 0 for fp16. + if x.dtype == np.float16: + return np.sinc(x.astype(np.float32)) + else: + return np.sinc(x) + +def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + return ( + SampleInput( + make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), + shape, + ) for size, shape in test_cases) + +def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs) + + m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + + cases = ( + ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)), + ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6)) + ) + + for a, b, c, d in cases: + yield SampleInput(m(a), args=(m(b), m(c), m(d))) + yield SampleInput(n(a), args=(n(b), n(c), n(d))) + +def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_cases: tuple[tuple] = ( + ((1, S), (2, S), (3, S),), + ((S, 1), (S, 2), (S, 3),), + ((1,), (2,), (3,),), + ((2, S), (S,)) + ) + + for shape, *other_shapes in test_cases: + yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) + # We also want to test mixed complex-non-complex inputs to block_diag + if dtype == torch.complex32 or dtype == torch.complex64: + non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64 + make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes)) + +def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs): + small_S = 2 + test_cases = ( + ((S, S, 2), (S, S + 1, 2)), + ((S, S), (S, S)), + ((S, S, S), (S, S, S)), + ((3, 5), (3, 5)), + ((2, 3, 5), (2, 3, 5)), + ((1, 2, 3), (1, 2, 3)), + ((1, 1), (S, 1)), + ((0, 5), (4, 5)), + ((4, 5), (0, 5)), + ((0, 4, 5), (3, 5)), + ((4, 5), (0, 3, 5)), + ((0, 4, 5), (1, 3, 5)), + ((1, 4, 5), (0, 3, 5)), + # Using S here would make this one test take 9s + ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)), + ((small_S, 1, 1, small_S), (1, small_S, small_S)), + ((1, 1, small_S), (small_S, 1, small_S, small_S)), + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + # FIXME add an override for JIT and revert 0. back to 0 + # since it's accepted by eager + for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]: + for t1_size, t2_size in test_cases: + # The args should never be non-contiguous as this is not supported in the backward + yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm) + +def _fill_np(a, value): + a = a.copy() + a.fill(value) + return a + +def _fill_sample_kwargs(device, dtype, input): + if dtype is torch.bool: + value = True + else: + value = 3 + + return ({'value': value}, {'value': value}) + +def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) + + # Adds a sample input where both tensors have the same values + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + lhs = make_arg((S, S)) + yield SampleInput(lhs, args=(lhs.clone(),)) + +def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape x number of tensors + cases = ( + ((3, 4), 1), + ((1, 2, 1, 4), 3), + ((0, 1, 0), 2),) + + for shape, num_tensors in cases: + tensors = [make_arg(shape) for _ in range(num_tensors)] + for dim in range(-1, len(shape) - 1): + yield SampleInput(tensors, args=(dim,)) + + +def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs): + # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. + # If all input tensors have the same ndims, we support both negative and non-negative dim. + # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. + # No requirements for (wrapped_dim, ...)-th dimension. + # 3. Expect positive num_chunks + # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element + # 5. Non-contiguous input tensors are allowed. + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + same_ndim_cases = ( + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 129]), + torch.Size([1, 2, 297]), + ], -1, 5 + ), + ( + [ + torch.Size([1, 2, 3]), + torch.Size([1, 2, 3]), + ], 1, 5 + ), + ( + [ + torch.Size([3, 3, 2, 1]), + torch.Size([1, 4, 2, 2]), + torch.Size([2, 1, 3, 3]), + ], 0, 2 + ), + ) + for sizes, dim, num_chunks in same_ndim_cases: + tensors = [make_arg(size) for size in sizes] + yield SampleInput(tensors, args=(dim, num_chunks)) + + different_ndim_case = [ + torch.Size([2, 3, 3]), + torch.Size([2, 3, 1, 2]), + torch.Size([2, 3]), + torch.Size([2, 3, 2]), + torch.Size([2, 3, 271]), + ] + max_dim, num_chunks = 2, 3 + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + tensors.append(make_arg(size)) + yield SampleInput(tensors, args=(dim, num_chunks)) + + # non-contiguous + for dim in range(max_dim): + tensors = [] + for size in different_ndim_case: + # make the last 2 dims column-major (i.e. non-contiguous) + t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1) + tensors.append(t) + yield SampleInput(tensors, args=(dim, num_chunks)) + +def error_inputs_chunk_cat(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # input tensors have different ndims but dim is negative + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims', + ) + + # input tensors have different ndims but dim >= ndim of some input tensors + sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects dim < ndim for all input tensors', + ) + + # some tensors have different sizes for 0, ..., dim-1 dimensions. + sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors', + ) + + # negative num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # zero as num_chunks + sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects positive num_chunks', + ) + + # empty input tensor list + dim, num_chunks = 0, 1 + yield ErrorInput( + SampleInput([], args=(dim, num_chunks)), + error_regex='_chunk_cat expects a non-empty input tensor list', + ) + + # empty input tensor with 0 elements + sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1 + tensors = [make_arg(size) for size in sizes] + yield ErrorInput( + SampleInput(tensors, args=(dim, num_chunks)), + error_regex='_chunk_cat expects non-empty tensor', + ) + + +def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple, tuple, dict] = ( # type: ignore[assignment] + ((S, S), (S, S), {'dim': -1}), + ((S, S), (S, S), {'dim': 1}), + ((M, S), (S, S), {'dim': 0}), # different shapes + ((1, 2, 3), (1, 2, 3), {'dim': -2}), + ((0,), (0,), {'dim': 0}), # empty tensor + ((0,), (S, S), {'dim': 1}), # empty tensor with unempty and dim=1 (special case for legacy_cat_wrap_dim) + ((0, S), (S, S), {'dim': 0}), + ((1,), (1,), {}) # dim not passed, fallback to default + ) + + for input_shape1, input_shape2, kwargs in cases: + yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs) + + # from coat_lite_mini + yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),) + +def error_inputs_cat(op_info, device, **kwargs): + + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for more than one element of the written-to tensor refer to a single memory location + yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))], + kwargs={'out': make_arg((1, S)).expand((2 * S, S))}), + error_regex='unsupported operation') + + # error inputs for empty tensors + yield ErrorInput(SampleInput([], kwargs={'dim': 1}), + error_regex='non-empty list of Tensors', error_type=ValueError) + + # error inputs for different sizes + yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}), + error_regex='Sizes of tensors must match except in dimension') + + # error inputs for different dimensions + yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}), + error_regex='Tensors must have same number of dimensions') + + # error inputs for same memory locations + x = torch.zeros((0), device=device) + y = torch.randn((4, 6), device=device) + + err_msg = "the written-to tensor refer to a single memory location" + + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}), + error_regex=err_msg) + yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}), + error_regex=err_msg) + + z = torch.zeros((4, 6), device=device) + yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}), + error_regex=err_msg) + + # error inputs for different devices + if torch.device(device).type == 'cuda': + x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32) + y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32) + yield ErrorInput(SampleInput((x_cuda, y_cpu)), + error_regex='Expected all tensors to be on the same device') + + # error inputs for different input sizes for more than 2 tensors + yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]), + error_regex='Tensors must have same number of dimensions') + + yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))], + kwargs={'dim': 1}), + error_regex='Sizes of tensors must match') + + # error inputs for None input + yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError, + error_regex='got None') + + # error inputs for zero-dimensional tensors + yield ErrorInput(SampleInput([make_arg(()), make_arg(())]), + error_regex='zero-dimensional.*cannot be concatenated') + + # error inputs for different dtype of out tensors + d = make_tensor((2, 3), device=device, dtype=torch.double if not device.startswith("mps") else torch.float16) + x = make_tensor((2, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError, + error_regex='invalid combination of arguments') + +def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Noncontiguous type promoting tensors + a = make_arg((3, 4, 2)) + b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) + c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) + + yield SampleInput((a, b, c), kwargs={'dim': 1}) + + # Special 1D tensor with dim length of 0 case + a = make_arg((0,)) + b = make_arg((3, 2, 2)) + + yield SampleInput((a, b, a)) + yield SampleInput((a, a, a)) + +def _elementwise_type_promo_np(*args, type_promotion_kind): + def _maybe_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x) + return x + + flattened = pytree.arg_tree_leaves(*args) + transformed = tuple(_maybe_torch(a) for a in flattened) + result_dtype, _ = prims.utils.elementwise_dtypes( + *transformed, + type_promotion_kind=type_promotion_kind) + return torch_to_numpy_dtype_dict[result_dtype] + +def _cat_np(input_seq, dim=0): + inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0)) + + if len(inputs) == 0: + np_dtype = _elementwise_type_promo_np( + input_seq, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH) + return np.empty(0, dtype=np_dtype) + + return np.concatenate(inputs, axis=dim) + +def _floor_divide_np(a, b): + dtype = _elementwise_type_promo_np( + a, + b, + type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + if isinstance(a, np.ndarray): + a = a.astype(dtype) + if isinstance(b, np.ndarray): + b = b.astype(dtype) + return np.floor_divide(a, b) + +def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + tensor_shapes = ( + # First Tensor being 1-D is special + # case for hstack + ((S,), (S,), (S,)), + ((S, S), (S, S), (S, S)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + yield SampleInput(tensors) + +def error_inputs_hstack_dstack_vstack(op, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + tensor_shapes = ( + ((S,), (S, S, S, S), (S,)), + ) + for s1, s2, s3 in tensor_shapes: + tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) + # Different dimension tensor + yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions") + + # empty tensor list + yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList") + +def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs): + # Note: we don't do any tests where we unbind along 0-length dims + # because in that case unbind returns and empty tuple, and that breaks + # some assumptions in some backward tests in test_ops.py + shape_dims = (((S,), 0), + ((S, S), 0), + ((S, S), 1), + ((S, S), -1), + ((S, 0, S), 0), + ((S, S, S), 1), + ) + for shape, dim in shape_dims: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + requires_grad=requires_grad), + args=(dim,)) + +def error_inputs_unbind(op_info, device): + make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError, + error_regex="Dimension specified as 0 but tensor has no dimensions") + yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError, + error_regex="Dimension out of range") + +def reference_unbind(t, dim): + """A numpy implementation of torch.unbind""" + return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim)) + +def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device)) + yield SampleInput( + make_arg((M, S)), + 0, + gather_variable((S, S), 1, M, True, device=device).to(torch.int32)) + yield SampleInput( + make_arg((M, S)), + 1, + gather_variable((M, S // 2), 0, S, True, device=device)) + # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006 + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([], dtype=torch.uint8, device=device)) + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([[], []], dtype=torch.uint8, device=device)) + # 0D tensor case + yield SampleInput( + make_arg(()), + 0, + torch.tensor([0], dtype=torch.int64, device=device)) + yield SampleInput( + make_arg(()), + 0, + torch.tensor(0, dtype=torch.int64, device=device)) + +def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o): + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, idx.size(dim) + 1) + idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] + +def error_inputs_gather(op_info, device, **kwargs): + # src is [1, 2] + # [3, 4] + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + + # idx is [0, 0] + # [1, 0] + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + + # Index should be smaller than self except on dimension 1 + bad_src = make_tensor((1, 1), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), + error_regex="Size does not match at dimension 0") + + # TODO: FIXME + # out.dtype must match src.dtype + # Creates new src & idx since SampleInputs can't share tensors + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + out = torch.empty((2, 2), device=device, dtype=torch.float64) + yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), + error_regex="Expected out tensor to have dtype") + + # src and index tensors must have the same # of dimensions + # idx too few dimensions + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor((0, 0), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # src too few dimensions + src = torch.tensor((1, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(0, idx)), + error_regex="Index tensor must have the same number of dimensions") + + # index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(src, args=(1, idx,)), + error_regex="index 23 is out of bounds for dimension") + + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_take(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + src = torch.rand((6,), device=device) + ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])), + error_type=RuntimeError, + error_regex='unsupported operation') + +# Error inputs for scatter +def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): + # Error when self.dtype != src.dtype (and src is not a scalar) + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.double) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Expected self.dtype to be equal to src.dtype") + + # Index and destination must have the same number of dimensions + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as self tensor") + + # Index and src must have the same number of dimensions when src is not a scalar + src = make_tensor((2, 5, 2), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="Index tensor must have the same number of dimensions as src tensor") + + # Index out of bounds + # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices + if torch.device(device).type == 'cpu': + src = make_tensor((2, 5), device=device, dtype=torch.float32) + idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) + dst = torch.zeros((3, 5), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(dst, args=(0, idx, src)), + error_regex="index 34 is out of bounds for dimension 0 with size 3") + +def error_inputs_renorm(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError, + error_regex="needs at least 2 dimensions, got 0 dimensions") + + +def error_inputs_ormqr(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError, + error_regex="input must have at least 2 dimensions") + + # https://github.com/pytorch/pytorch/issues/85218 + tensor_0 = torch.full((5, 0,), 1, device=device) + tensor_1 = torch.full((5,), 1, device=device) + tensor_2 = torch.full((5, 5,), 1, device=device) + bool_3 = True + bool_4 = True + yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError, + error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)") + + +def error_inputs_diag(op_info, device, **kwargs): + zero_d = torch.randn((), device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + zero_d = torch.randn(1, 1, 1, device=device) + yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, + error_regex="1D or 2D") + +def error_inputs_embedding(op_info, device, **kwargs): + indices = torch.rand(2, 2, device=device).long() + weights = [ + torch.tensor(1.0, device=device), + torch.tensor(1.0, device=device).reshape(1, 1, 1), + ] + + for weight in weights: + yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError, + error_regex="'weight' must be 2-D") + + +def error_inputs_t(op_info, device, **kwargs): + yield ErrorInput( + SampleInput(torch.randn(2, 3, 4, 5, device=device)), + error_regex="expects a tensor with <= 2", + ) + + +def error_inputs_multinomial(op_info, device, **kwargs): + x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="prob_dist must be 1 or 2 dim") + + x = torch.empty(1, 2, dtype=torch.long, device=device) + yield ErrorInput(SampleInput(x, args=(2,)), + error_regex="multinomial only supports floating-point dtypes for input") + + x = torch.empty(1, 2, dtype=torch.double, device=device) + y = torch.empty(1, 2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), + error_regex="multinomial expects Long tensor out") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(0,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(-1,)), + error_regex="cannot sample n_sample <= 0 samples") + + x = torch.empty(2, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3, False,)), + error_regex="cannot sample n_sample > prob_dist") + + x = torch.empty(16777217, dtype=torch.double, device=device) + yield ErrorInput(SampleInput(x, args=(3,)), + error_regex="number of categories cannot exceed") + + inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan)) + + err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0" + err_msg2 = "invalid multinomial distribution" + + rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) + + if torch.device(device).type == 'cpu': + for rep in rep_arg: + kwargs = {'num_samples': 2, 'replacement': rep} + + for shape in inputs: + # error case when input tensor contains `inf`, `nan` or negative element + yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), + error_regex=err_msg1 if rep is False else err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input + x = torch.zeros(3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input + x = torch.zeros(3, 3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + + # error case for the invalid multinomial distribution + x[1, :] = 1 + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) + +def error_inputs_gradient(op_info, device, **kwargs): + for dtype in [torch.long, torch.float32, torch.complex64]: + t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype) + + dim = (1, 0) + spacing = [0.1] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected spacing to be unspecified, a scalar ') + + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)), + error_type=RuntimeError, + error_regex='torch.gradient only supports edge_order=1 and edge_order=2.') + + dim = (1, 1) + spacing = 0.1 + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='dim 1 appears multiple times in the list of dims') + + dim = (0, 1) + coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')] + yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each tensor to be on the same device,') + + yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)), + error_type=IndexError, error_regex='') + + t = torch.tensor([[1], [2], [3]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + + t = torch.tensor([[1, 2], [3, 4]]) + yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)), + error_type=RuntimeError, + error_regex='torch.gradient expected each dimension size to be at least') + +def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True)) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S)) + yield SampleInput(make_arg(S), training=False) + +def error_inputs_rrelu(op_info, device, **kwargs): + input = make_tensor((S, S), device=device, dtype=torch.float32) + yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}), + error_regex='Lower bound should be less than or equal to the upper bound') + +def error_inputs_masked_select(op_info, device, **kwargs): + x = torch.rand((1,), device=device).expand((3,)) + y = torch.rand((6,), device=device) + mask = torch.tensor([True, False, True, True, False, False], device=device) + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)), + error_type=RuntimeError, + error_regex='unsupported operation') + + yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_median(op_info, device, **kwargs): + x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan], + [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device) + if device == 'cuda': + yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))), + error_type=RuntimeError, + error_regex='CUDA Tensors cannot have more than 25 dimensions') + else: + return + + +def error_inputs_index_select(op_info, device, **kwargs): + x = torch.rand((1, 6), device=device).expand((2, 6)) + y = torch.rand((3, 6), device=device) + ind = torch.tensor([0, 1], dtype=torch.int64, device=device) + + yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)), + error_type=RuntimeError, + error_regex='unsupported operation') + +def error_inputs_index_add(op_info, device, **kwargs): + result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]]) + source = torch.tensor([2., 4.]) + + yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)), + error_type=RuntimeError, + error_regex=r'source tensor shape must match self tensor shape, ' + r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]') + +def error_inputs_logcumsumexp(op_info, device, **kwargs): + dim = 3 + srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)] + for src in srcs: + yield ErrorInput(SampleInput(src, args=(dim,)), + error_type=IndexError, + error_regex='Dimension out of range') + +def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput( + make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0) + + # `indices` broadcast + yield SampleInput( + make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1) + + # `self` broadcast + yield SampleInput( + make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1) + + # without `dim` arg + yield SampleInput( + make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device)) + + # Negative indices sample — guarded against python_ref + if not kwargs.get('is_python_ref', False): + neg_idx = gather_variable((S, S), 1, S, True, device=device) - S + yield SampleInput( + make_arg((S, S)), + neg_idx, + 1) + + +def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs): + + # Error Inputs for zero-dim tensors, when 'dim' arg is not provided. + shape = (S, 0, S) + err_msg_amax_amin = "reduction" + err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity" + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin) + elif op_info.name == 'aminmax': + yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax) + + # Error Inputs for tensors with more than 64 dimension + sizes = [1] * 65 + err_msg1 = "only tensors with up to 64 dims are supported" + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}), + error_regex=err_msg1) + yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}), + error_regex=err_msg1) + + # Error Inputs for repeated 'dim' + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + dims = [(0, 0), (0, -4)] + err_msg2 = "in the list of dims" + x = torch.randn(S, S, S, S, device=device) + for dim in dims: + yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2) + + # Error Input for illegal dtype + input5 = torch.randn(L, L, dtype=torch.float32, device=device) + max_values = torch.empty(L, dtype=torch.float32, device=device) + min_values = torch.empty(L, dtype=torch.double, device=device) + illegal_values = torch.empty(L, dtype=torch.int, device=device) + + # Unlike regular PyTorch, amax and amin refs don't require input and out + # dtypes to match exactly: + # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824 + if is_ref: + err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype " + "torch.int32, but this can't be cast because it is not safe!") + else: + err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float " + "for input's dtype and Int for out's dtype.") + err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead" + + if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']: + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}), + error_regex=err_msg_amax_amin2) + elif op_info.name == 'aminmax': + yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}), + error_regex=err_msg_aminmax2) + + # Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim + err_msg3 = "reduction" + # FIXME: eager and ref impl throw different types of errors + error_type = IndexError if 'refs' not in op_info.name else RuntimeError + yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}), + error_type=error_type, error_regex=err_msg3) + +def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): + test_cases: tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S, S), {}), + ((S, S, S), {'dim': 1}), + ((S, S, S), {'dim': 1, 'keepdim': True}), + ((), {'dim': 0}), + ((), {}), + ((), {'dim': 0, 'keepdim': True}), + ((S, 0, S), {'dim': 0}), + ) + + for shape, kwargs in test_cases: + yield SampleInput( + make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad), + **kwargs) + +def error_inputs_diff(op_info, device, **kwargs): + t = torch.rand((1, 3), device=device) + n = -1 + yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs), + error_type=RuntimeError, + error_regex=f'order must be non-negative but got {n}') + +def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + test_cases = ( + ((1,), 0, None, None), + ((S,), 0, None, None), + ((S, 1), 0, None, None), + ((S, 1), 1, None, None), + ((S, S), 0, None, None), + ((S, S), 1, None, None), + ((S, S), 0, (1, S), (2, S)), + ((S, S), 0, None, (2, S)), + ((XS, XS, XS), 1, None, None), + ((XS, XS, XS), 2, None, None), + ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)), + ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)), + ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),) + + for size, dim, size_prepend, size_append in test_cases: + prepend_size = 0 if (size_prepend is None) else size_prepend[dim] + append_size = 0 if (size_append is None) else size_append[dim] + dim_size = size[dim] + prepend_size + append_size + for n in range(dim_size): + input_tensor = make_arg(size) + prepend = make_arg(size_prepend) if size_prepend else None + append = make_arg(size_append) if size_append else None + yield SampleInput(input_tensor, n, dim, prepend, append) + + # add some samples with n > dim_size + yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1) + yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS))) + +def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]): + input_tensor = make_arg(size) + weight_tensor = make_arg(size) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = make_arg((bin_ct + 1,)) + sorted_bins, _bins_indices = torch.sort(bins_tensor) + yield SampleInput(input_tensor, sorted_bins, + weight=weight_tensor, density=density) + +def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S)) + bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3)) + + for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]): + input_tensor = make_arg(size) + bin_ct = bin_ct_pattern[:size[-1]] + weight_tensor = make_arg(size[:-1]) if weighted else None + + yield SampleInput(input_tensor, bin_ct, + weight=weight_tensor, density=density) + + bins_tensor = [make_arg(ct + 1) for ct in bin_ct] + yield SampleInput(input_tensor, bins_tensor, + weight=weight_tensor, density=density) + +def error_inputs_histogramdd(opinfo, device, **kwargs): + invalid_bins = [1, 1, 1, 1, 1] + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input." + yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg) + +def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for size, min, max in product(sizes, [0, -10], [0, 10]): + # construct sample input omitting bins arg + yield SampleInput(make_arg(size), min=min, max=max) + + # construct sample inputs with a few different bins values + for bins in [1, 3, 10]: + yield SampleInput(make_arg(size), bins=bins, min=min, max=max) + +def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for size, weighted in product((S, M), [False, True]): + input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device) + weight_tensor = make_arg((size,)) if weighted else None + + max_val = int(input_tensor.max().item()) + + for minlength in [0, max_val // 2, max_val, 2 * max_val]: + yield SampleInput( + input_tensor, weights=weight_tensor, minlength=minlength) + +def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S)) + + if reference_inputs_mode: + sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33)) + + for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]): + input_tensor = make_arg(input_shape) + boundaries = make_arg(nb).msort() + + yield SampleInput(input_tensor, boundaries, + out_int32=out_int32, right=right) + +reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True) + +def error_inputs_bucketize(opinfo, device, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False) + yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))), + error_regex="boundaries tensor must be 1 dimension") + +def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # (unsorted tensor size, (input sizes,), is_scalar) + sizes = ( + ((0,), ((0,),), False), + ((M,), ((), (M,), (M, M)), False), + ((0, 0), ((0, 0),), False), + ((M, M), ((M, M),), False), + ((0, 0, 0), ((0, 0, 0),), False), + ((M, M, M), ((M, M, M),), False), + ((L,), ((),), True), + ) + + for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product( + sizes, [False, True], [False, True], [False, True] + ): + unsorted_tensor = make_arg(size, noncontiguous=noncontiguous) + for input_size in input_sizes: + input = make_arg(input_size, noncontiguous=noncontiguous) + if is_scalar: + input = input.item() + if np.prod(size) == 0: + boundary_tensor = unsorted_tensor + sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous) + else: + boundary_tensor, sorter = torch.sort(unsorted_tensor) + side = "right" if right else "left" + + yield SampleInput(boundary_tensor, input, out_int32=out_int32, right=right) + yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side) + + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter) + yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter) + +def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + test_cases_float = ( + ((S,), None, None, 1), + ((S,), 2., None, 1), + ((S, S), None, None, 2), + ((S, S), [2.0, 2.1], None, 1), + ((S, S), [2.0, 2.1], (0, 1), 1), + ((4, 4, 4), [2., 1.], (0, 1), 2), + ) + for size, spacing, dim, edge_order in test_cases_float: + t = make_arg(size) + yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order) + + test_cases_tensor = ( + ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1), + ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2), + ) + for size, coordinates, dim, edge_order in test_cases_tensor: + t = make_arg(size) + coordinates_tensor_list = [] + for coords in coordinates: + # `coords` will always contain floating point values and Python 3.10 does not support this + # implicit conversion to an integer using `__int__` + # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed + a = torch.tensor(coords, device=device) + coordinates_tensor_list.append(a.to(dtype)) + yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order) + +def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + test_args = [ + ([1, 2],), + (slice(0, 3),), + ((slice(0, 3), 1),), + (([0, 2, 3], [1, 3, 3], [0, 0, 2]),), + (([0, 0, 3], [1, 1, 3], [0, 0, 2]),), + ((slice(None), slice(None), [0, 3]),), + ((slice(None), [0, 3], slice(None)),), + (([0, 3], slice(None), slice(None)),), + (([0, 3], [1, 2], slice(None)),), + (([0, 3], ),), + (([0, 3], slice(None)),), + (([0, 3], Ellipsis),), + (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),), + (index_variable(2, S, device=device),), + (mask_not_all_zeros((S,)),), + ] + + for args in test_args: + yield SampleInput(make_arg((S, S, S)), args=args) + + yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),)) + +def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + for accumulate in [False, True]: + # Test with indices arg + yield SampleInput( + make_arg((S, S,)), + # As defined in the docs, if accumulate is false, duplicate indices are not supported + (index_variable(2 if accumulate else 1, S, device=device),), + make_arg((2 if accumulate else 1, S)), + accumulate=accumulate) + + # Test with mask arg + mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,)) + yield SampleInput( + make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate) + +def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs): + def small_3d_unique(): + res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + def large_1d_unique(): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype).requires_grad_(requires_grad) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique()) + + # Test cases for small 3d tensors. + # Imitates legacy tests from test/test_torch.py + dims = range(-3, 3) + flag = [True, False] + for dim, descending, stable in product(dims, flag, flag): + # default schema without stable sort + if not (dtype == torch.bool and torch.device(device).type == 'cuda'): + # bool and cuda requires stable sort for stable results, at least + # for the return index + yield SampleInput(small_3d_unique(), dim, descending) + # schema with stable sort, no CUDA support yet + if torch.device(device).type == 'cpu': + yield SampleInput( + small_3d_unique(), dim=dim, descending=descending, stable=stable) + + # Test cases for scalar tensor + tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(torch.tensor(1, **tensor_opt)) + yield SampleInput(torch.tensor(1, **tensor_opt), 0) + yield SampleInput(torch.tensor(1, **tensor_opt), 0, True) + + # Test cases for empty tensor + yield SampleInput(torch.tensor((), **tensor_opt)) + yield SampleInput(torch.tensor((), **tensor_opt), 0) + yield SampleInput(torch.tensor((), **tensor_opt), 0, True) + + # Test cases for stable sort + yield SampleInput(small_3d_unique(), stable=True) + yield SampleInput(small_3d_unique(), dim=0, stable=True) + yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True) + +def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S)) + for x_size in sizes: + # threshold and values args must be numbers + yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item()) + +def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + for shape, sorted, return_inverse, return_counts, dim in \ + product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]): + # torch.unique cannot be called if the input tensor has a zero dimension which isn't the selected dim + if 0 in shape and shape.index(0) is not dim: + continue + + # skip invalid dim args + if dim is not None and (dim < -len(shape) or dim >= len(shape)): + continue + + kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) + + # construct a test case with only one distinct value + input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with mixed 0s and 1s + input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\ + .to(dtype).requires_grad_(requires_grad) + yield SampleInput(input_t, **kwargs) + + # construct a test case with many different values + yield SampleInput(make_arg(shape), **kwargs) + +def sample_inputs_unique_consecutive(*args, **kwargs): + for sample_input in sample_inputs_unique(*args, **kwargs): + if not sample_input.kwargs["sorted"]: + sample_input.kwargs.pop("sorted") + yield sample_input + +def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8), (5,)), + ((3, 8, 8), 5), + ((3, 8, 8), 1) + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((1, 8, 8, 8), (5, 7)), + ((2, 8, 8, 8), (None, 7)), + ((1, 8, 4, 3), (5, None)), + ((1, 8, 4, 3), (None, None)), + ((1, 8, 4, 3), (5)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 2") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + ((0, 8, 8, 8, 8), (5, 7, 4)), + ((1, 8, 4, 3, 7), (None, None, None)), + ((1, 8, 4, 3, 7), (1, 1, 1)), + ((3, 3, 8, 8, 6), (5, 7, None)), + ((1, 3, 8, 8, 6), (5, None, 2)), + ((3, 3, 8, 8, 6), (None, 3, 2)), + ) + + for input_shape, output_size in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=(output_size,)) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=(output_size,)) + + +def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="output_size must be 3") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="elements of output_size must be greater than or equal to 0") + + +def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8), (5,)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((3, 4, 4), 3), + ((3, 4, 4), 1) + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + + +def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()), + error_regex="'output_size' should contain one int") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)), + error_regex="Trying to create tensor with negative dimension") + +def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8), (5, 7)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 4), (2, 3)), + ((2, 4, 4, 4), (None, 3)), + ((2, 4, 4, 4), (1, 1)), + ((1, 4, 4, 3), (3, None)), + ((1, 4, 4, 3), (None, None)), + ((1, 4, 4, 3), (3)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 3") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)), + error_regex="Trying to create tensor with negative dimension") + + +def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as (input shape, output size) + cases = ( + # ((0, 8, 8, 8, 8), (5, 7, 4)), + # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1] + ((1, 4, 4, 3, 5), (None, None, None)), + ((1, 4, 4, 3, 5), (1, 1, 1)), + ((3, 3, 4, 4, 6), (2, 3, None)), + ((1, 3, 4, 4, 6), (3, None, 2)), + ((3, 3, 4, 4, 6), (None, 3, 2)), + ) + + for shapes, return_idx in product(cases, (True, False)): + # Batched + yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx)) + # Unbatched + yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx)) + +def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for incorrect input dimension + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)), + error_type=ValueError, error_regex="Input dimension should be at least 4") + + # error inputs for empty output + yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()), + error_regex="internal error") + + # error inputs for output_size lesser than 0 + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)), + error_regex="Trying to create tensor with negative dimension") + + +class _TestParamsMaxPoolBase: + + def __init__(self) -> None: + self.kwargs = { + 'kernel_size': [3], + 'stride': [2, None], + 'ceil_mode': [True, False], + 'padding': [0, 1], + 'dilation': [1], + 'return_indices': [True, False] + } + + self.shapes = [ + [1, 2, None], # batch + [2], # channels + [3, 6] # signal + ] + + def _gen_shape(self): + for shape in product(*self.shapes): + # shape[0] is None indicates missing batch dimension + if shape[0] is None: + shape = shape[1:] + + yield shape, torch.contiguous_format + # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format + if len(self.shapes) == 4 and len(shape) == 4: + yield shape, torch.channels_last + + def _gen_kwargs(self): + keys = self.kwargs.keys() + for values in product(*self.kwargs.values()): + yield dict(zip(keys, values, strict=True)) + + def gen_input_params(self): + yield from product(self._gen_shape(), self._gen_kwargs()) + +class _TestParamsMaxPool1d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3,)] + self.kwargs['stride'] += [(2,)] + self.kwargs['padding'] += [(1,)] + self.kwargs['dilation'] += [(1,)] + +class _TestParamsMaxPool2d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2)] + self.kwargs['stride'] += [(2, 1)] + self.kwargs['padding'] += [(1, 1)] + self.kwargs['dilation'] += [(1, 2)] + + self.shapes.append([6]) + +class _TestParamsMaxPool3d(_TestParamsMaxPoolBase): + + def __init__(self) -> None: + super().__init__() + self.kwargs['kernel_size'] += [(3, 2, 3)] + self.kwargs['stride'] += [(2, 1, 2)] + self.kwargs['dilation'] += [(1, 2, 1)] + + self.shapes.append([6]) + self.shapes.append([5]) + +def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + params_generator_type_dict = { + 'nn.functional.max_pool1d': _TestParamsMaxPool1d, + 'nn.functional.max_pool2d': _TestParamsMaxPool2d, + 'nn.functional.max_pool3d': _TestParamsMaxPool3d, + 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d, + } + + params_generator = params_generator_type_dict[op_info.name]() + for (shape, memory_format), kwargs in params_generator.gen_input_params(): + arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) + yield SampleInput(arg, kwargs=kwargs) + +def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs): + out, indices = torch.nn.functional.max_pool2d_with_indices( + *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True) + grad_out = torch.ones_like(out) + if stride is None: + stride = kernel_size + out_b = torch.ops.aten.max_pool2d_with_indices_backward.default( + grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices) + return out_b + +def error_inputs_max_pool1d(op_info, device, **kwargs): + # Toggle requires_grad because `max_pool1d` has different path + # based on whether `requires_grad` is set or not. + for requires_grad in (True, False): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default + yield ErrorInput(SampleInput(x, + kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for input tensor + error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input + yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: unbatched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad), + kwargs={'kernel_size': 1}), + error_regex=error_msg) + + # error inputs for empty input with stride=0 + error_msg = 'stride must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}), + error_regex=error_msg) + + # error inputs for empty input with dilation=0 + error_msg = 'dilation must be greater than zero, but got 0' + yield ErrorInput(SampleInput(make_arg((3, 3, 3)), + kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}), + error_regex=error_msg) + + # error inputs for invalid output size + error_msg = 'Invalid computed output size: -2' + yield ErrorInput(SampleInput(make_arg((2, 2, 2)), + kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}), + error_regex=error_msg) + + # error inputs when kernel_size=0 + error_msg = 'kernel_size must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}), + error_regex=error_msg) + + # error inputs for strides > 0 + error_msg = 'stride must be greater than zero' + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}), + error_regex=error_msg) + + +def error_inputs_max_pool2d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size : int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size : tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input' + yield ErrorInput(SampleInput(make_arg((1, 0, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched input with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + + +def error_inputs_max_pool3d(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # error inputs when pad is negative + x = make_arg((0, 1, 49, 50)) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': -1, 'return_indices': True}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 (kernel_size: int) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs when pad > kernel_size / 2 (kernel_size: tuple) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, + 'padding': 4, 'return_indices': True}), + error_regex='pad should be at most half of effective kernel size') + + # error: unbatched input with 0 sized non-batch dims. + err_msg = r'Expected input\'s non-batch dimensions to have positive length' + yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: batched inputs with 0 sized non-batch dims. + yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)), + kwargs={'kernel_size': 1}), + error_regex=err_msg) + + # error: inputs when kernel size too large for input + yield ErrorInput(SampleInput(make_arg((1, 1, 1, 4, 4)), + kwargs={'kernel_size': 2}), + error_regex='Output size is too small') + + + +def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: tuple[tuple[int, ...], dict] = ( + ((2, 1, 4, 5), {'p': 1., 'dim': 2}), + ((2, 3, 4, 5), {'p': 2., 'dim': 1}), + ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}), + ((1, 3, 4, 5), {'p': -1., 'dim': 1}), + ((1, 3, 4, 5), {'p': 0., 'dim': -1}), + ((), {'p': 1.2, 'dim': 0}), + ((2, 3, 4, 5), {}), + ((2, 3, 4, 5), {'eps': 1e-4})) + + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), kwargs=kwargs) + + +def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups): + # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0)) + # a = conv(Wr, xr, br), + # b = conv(Wi, xi, 0), + # c = conv(Wr + Wi, xr + xi, br + bi) + # conv(W, x, b) = a - b + i(c - a - b) + + grad_output_ = torch.view_as_real(grad_output) + grad_output_r = grad_output_[..., 0] + grad_output_i = grad_output_[..., 1] + + weight_ = torch.view_as_real(weight) + weight_r = weight_[..., 0] + weight_i = weight_[..., 1] + + a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups) + b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups) + c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups) + + return (a - b) + 1j * (c - a - b) + + +def conv_transpose_ref(input, weight, bias, stride=1, padding=0, + output_padding=0, dilation=1, groups=1, + fn=None): + # Derivative of `conv` is `conv_transpose`. + # To verify the correctness of `conv_transpose`, + # we rely `torch.nn.grad` implementation (which is tested in test_nn.py) + # for floating dtypes. + + assert fn is not None + + grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input, + torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input, + torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input} + batched_dim_map = {torch.nn.functional.conv_transpose1d: 3, + torch.nn.functional.conv_transpose2d: 4, + torch.nn.functional.conv_transpose3d: 5} + + # Input for `ref` is ndarray. + input, weight = torch.from_numpy(input), torch.from_numpy(weight) + + is_batched = len(input.shape) == batched_dim_map[fn] + if not is_batched: + input = input.unsqueeze(0) + + if bias is not None: + bias = torch.from_numpy(bias) + unsqueeze_dims = input.ndim - 2 + for _ in range(unsqueeze_dims): + bias = bias.unsqueeze(1) + + grad_output = input + # Get the input shape for grad_fn. + conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None, + stride=stride, padding=padding, output_padding=output_padding, + groups=groups, dilation=dilation) + input_size = conv_transpose_output.shape + + grad_fn = grad_fn_map[fn] + if weight.dtype.is_complex: + out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups) + else: # Floating + out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups) + + if bias is not None: + out = out + bias + + return out.squeeze(0) if not is_batched else out + + +def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4), (3, 3, 3), (3,), + {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}), + ((2, 2, 4), (2, 2, 4), (4,), + {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}), + ((1, 1, 4), (1, 1, 4), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}), + ((1, 1, 4), (1, 2, 3), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5), (4, 8, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}), + ((2, 2, 4, 4), (2, 2, 4, 5), (4,), + {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}), + ((1, 1, 4, 5), (1, 1, 4, 3), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 1, 4, 3), (1, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}), + ((1, 2, 5, 5), (2, 4, 3, 3), None, {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + +def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, output_padding, groups, dilation) + cases: tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,), + {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}), + ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,), + {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}), + ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,), + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}), + ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None, + {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}), + ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None, + {}) + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias, + # and a dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}), + ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}), + ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}), + ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}), + # With defaults + ((1, 4, 5), (3, 4, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv1d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))), + error_regex="weight should have at least three dimensions") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def error_inputs_conv2d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for groups the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))), + kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported") + + +def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and a dict of values of (stride, padding, groups, dilation) + cases: tuple = ( + ((1, 3, 4, 4), (3, 3, 3, 3), (3,), + {'stride': (2, 2), 'padding': 2, 'groups': 1}), + ((2, 4, 8, 8), (2, 2, 3, 3), (2,), + {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}), + ((1, 2, 4, 3), (4, 2, 3, 4), None, + {'stride': 2, 'padding': 1, 'groups': 1}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 2, 'padding': "valid"}), + ((1, 4, 5, 5), (1, 4, 2, 3), (1,), + {'stride': 1, 'padding': "same", 'dilation': 3}), + # Below are the group related samples from common_nn.py + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}), + ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}), + ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}), + ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}), + ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}), + # With defaults + ((1, 4, 5, 5), (3, 4, 3, 3), None, {}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as shapes for input, weight, bias + # and dict of values of (stride, padding, dilation, groups) + cases: tuple = ( + ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}), + ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}), + ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}), + ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}), + ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}), + ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}), + ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}), + ) + + for input_shape, weight, bias, kwargs in cases: + # Batched + yield SampleInput(make_arg(input_shape), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + # Unbatched + yield SampleInput(make_arg(input_shape[1:]), args=( + make_arg(weight), + make_arg(bias) if bias is not None else bias + ), kwargs=kwargs) + + +def error_inputs_conv3d(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float64) + make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) + make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128) + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))), + error_regex="should be the same") + + # error inputs for different dtypes of input tensor and bias + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))), + error_regex="should be the same") + + # error inputs for negative strides + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported") + + # error inputs for negative padding + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'padding': (-1,)}), error_regex="negative padding is not supported") + + # error inputs for negative dilation + yield ErrorInput( + SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))), + kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero") + + # FIXME: https://github.com/pytorch/pytorch/issues/85656 + # error inputs for bias shape not equal to the output channels + # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))), + # error_regex="expected bias to be 1-dimensional with 1 elements") + + # error inputs for input.ndim != weight.ndim + yield ErrorInput( + SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))), + kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for the weight[0] are less than the number of groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}), + error_regex="expected weight to be at least 3 at dimension 0") + + # error inputs for invalid groups + yield ErrorInput( + SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)), + make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}), + error_regex="non-positive groups is not supported") + + # error inputs for padding='same' not supported by strided convolutions + yield ErrorInput( + SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)), + make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}), + error_regex="padding='same' is not supported for strided convolutions") + + +def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int, ...], int, float] = ( + ((1, 6, 3), 2, {'eps' : 0.5}), + ((2, 6, 3), 2, {'eps' : -0.5}), + ((1, 3), 1, {'eps' : 1e-5}), + ((0, 2), 1, {'eps' : 1e-5}), + ((S, S, S), 1, {'eps' : 0.5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(make_arg(input_shape), num_groups, **kwargs) + + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=(1,)) + +def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_group_norm( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: tuple[tuple[int, ...], int, float] = ( + ((20, 6, 10, 10), 3, {'eps' : 1e-5}), + # equivalent with InstanceNorm + # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) + ((20, 6, 10, 10), 6, {'eps' : 1e-5}), + # equivalent with LayerNorm + # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) + ((20, 6, 10, 10), 1, {'eps' : 1e-5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + input_tensor = make_arg(input_shape) + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(input_tensor, num_groups, **kwargs) + + +def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + # Ordered as: input shape, kwargs for momentum, eps + cases: tuple[tuple[int, ...], dict] = ( + ((S, S, S), {'momentum': 0.5, 'eps': 0.6}), + ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}), + ((3, 2, 4), {'momentum': -1.2}), + ((3, 2, 4), {'momentum': 0.0}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), + ) + + for input_shape, kwargs in cases: + # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) + channels = input_shape[1] + weight = make_arg(channels) + bias = make_arg(channels) + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + new_kwargs = { + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': weight, + 'bias': bias, + **kwargs + } + + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs=new_kwargs + ) + + # Checking for permutations of weights and biases as `None` + # instance_norm assumes that if there's a bias, there's a weight + weights = [channels, None] + biases = [None, None] + + for weight_channels, bias_channels in zip(weights, biases, strict=True): + running_mean = make_arg_without_requires_grad(channels, low=0) + running_var = make_arg_without_requires_grad(channels, low=0) + yield SampleInput( + make_arg(input_shape), + args=(), + kwargs={ + 'running_mean': running_mean, + 'running_var': running_var, + 'weight': make_arg(weight_channels) if weight_channels is not None else None, + 'bias': make_arg(bias_channels) if bias_channels is not None else None + } + ) + + # Test case for no optional kwargs + yield SampleInput(make_arg((1, 2, 3)), kwargs={}) + +def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + def make_bool_mask(*shape): + return torch.randint(0, 2, shape, device=device, dtype=torch.bool) + + def mask_two_rows(rows, cols): + mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device) + mask_two_rows[rows - 1] = False + mask_two_rows[rows - 3] = False + return mask_two_rows + + def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor: + return torch.where(~mask, float('-inf'), 0.0) + + def with_requires_grad(tensor): + return tensor.requires_grad_(requires_grad) + + def generate_input_from_mask(mask_shape, dim): + mask = make_bool_mask(*mask_shape) + input_tensor = make_arg(mask_shape) + masked_input = input_tensor + convert_to_float_mask(mask) + return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim}) + + samples = [ + # Basic 3D tensor with mask + generate_input_from_mask((2, 3, 4), dim=1), + # 2D tensor with mask, testing different dim + generate_input_from_mask((5, 5), dim=0), + # 4D tensor, testing with a different dim + generate_input_from_mask((2, 3, 4, 5), dim=2), + # Edge case: 1D tensor + generate_input_from_mask((10,), dim=0), + # Edge case: tensor with one dimension of size 1 + generate_input_from_mask((1, 5, 5), dim=1), + # Testing with all elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.zeros((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with no elements masked + SampleInput( + with_requires_grad( + make_arg((3, 3)) + + convert_to_float_mask( + torch.ones((3, 3), dtype=torch.bool, device=device) + ) + ), + kwargs={"dim": 1}, + ), + # Testing with two rows masked + SampleInput( + with_requires_grad( + make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3)) + ), + kwargs={"dim": 1}, + ), + ] + yield from samples + +def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + + # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs, + # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400 + + # With weight and a `None` bias + # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None)) + + # With `None` weight and bias (tests failing for this, see the link above) + # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,)))) + + +def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape, eps + cases: tuple[tuple[int, ...], tuple[int, ...], float] = ( + ((1, 2, 3), (1, 2, 3), 0.5), + ((2, 2, 3), (2, 3), -0.5), + ((1,), (1,), 1e-5), + ((1, 2), (2,), 1e-5), + ((0, 1), (1,), 1e-5), + ) + + for input_shape, normalized_shape, eps in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, bias, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, None, eps), + ) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, None, None, eps), + ) + +def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + +def error_inputs_group_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # check that input has minimum number of dimensions + err_msg1 = "Expected at least 2 dimensions for input tensor but received" + s1 = SampleInput(make_arg(1), args=(1,)) + yield ErrorInput(s1, error_regex=err_msg1) + + # check that the channels dimension is compatible with number of groups + err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" + s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) + yield ErrorInput(s2, error_regex=err_msg2) + +def error_inputs_native_layer_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + bias = make_arg((1, 2)) + err_msg3 = "Expected bias to be of same shape as normalized_shape" + s3 = SampleInput( + make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5) + ) + yield ErrorInput(s3, error_regex=err_msg3) + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + +def error_inputs_rms_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + input_shape = (1, 2, 3) + + err_msg1 = "Expected normalized_shape to be at least 1-dimensional" + s1 = SampleInput( + make_arg(input_shape), args=((), None, 1e-5) + ) + yield ErrorInput(s1, error_regex=err_msg1) + + normalized_shape = (1, 2, 3) + weight = make_arg((1, 2)) + err_msg2 = "Expected weight to be of same shape as normalized_shape" + s2 = SampleInput( + make_arg(input_shape), args=(normalized_shape, weight, 1e-5) + ) + yield ErrorInput(s2, error_regex=err_msg2) + + + err_msg4 = "Given normalized_shape=" + s4 = SampleInput( + make_arg((2, 2, 3)), args=((2, 2), None, 1e-5) + ) + yield ErrorInput(s4, error_regex=err_msg4) + + +def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, size and a kwarg dict for alpha, beta, and k + cases: tuple[tuple[int, ...], tuple[int, ...], dict] = ( + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}), + ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}), + ((1, 6, 3), 2, {'alpha': 3e-05}), + ((1, 6, 3), 2, {'beta': 0.5}), + ((1, 6, 3), 2, {'k': 1.25}), + ((1, 6, 3), 2, {}), + ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}), + ) + + for input_shape, size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs) + +def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs): + N = 5 + # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ? + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-5, high=5) + return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N)) + +def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4], [8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor = create_tensor(batch_shape + [in_feat]) + weight = create_tensor([out_feat, in_feat]) + if not has_bias: + yield SampleInput(input_tensor, weight) + continue + + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor, weight, bias) + + # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942 + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2)) + yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4)) + +def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs): + features_options = [[3, 4, 5], [8, 8, 8]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \ + itertools.product([True, False], features_options, batch_options): + input_tensor1 = create_tensor(batch_shape + [in_feat1]) + input_tensor2 = create_tensor(batch_shape + [in_feat2]) + weight = create_tensor([out_feat, in_feat1, in_feat2]) + if not has_bias: + yield SampleInput(input_tensor1, input_tensor2, weight) + continue + bias = create_tensor([out_feat]) + yield SampleInput(input_tensor1, input_tensor2, weight, bias) + +def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs): + features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]] + batch_options: list[list[int]] = [ + [], # no batch + [0], + [8], + [2, 3], + ] + create_tensor = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-2, high=2) + + for features, batch_shape in itertools.product(features_options, batch_options): + ndim = len(features) + len(batch_shape) + for dim in range(ndim): + input_tensor = create_tensor(batch_shape + features) + dim_size = input_tensor.size(dim) + if dim_size > 0 and dim_size % 2 == 0: + yield SampleInput(input_tensor, dim) + +def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + align_corners_options: tuple[Any, ...] = (None,) + if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): + align_corners_options = (True, False, None) + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'nearest-exact': [1, 2, 3], + 'linear': [1], + 'bilinear': [2], + 'bicubic': [2], + 'trilinear': [3], + 'area': [1, 2, 3] + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + def uneven_shape(size, rank, with_batch_channel=True): + rc = list(shape(size, rank, with_batch_channel)) + rc[-1] += 1 + if rank > 2: + rc[-2] -= 1 + return tuple(rc) + + if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for a more close to typical image processing usage + rank = 2 + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg(shape(270, rank), memory_format=memory_format), + shape(130, rank, False), + scale_factor=None, + mode=mode, + align_corners=False, + ) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for align_corners in align_corners_options: + for rank in ranks_for_mode[mode]: + yield SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + if rank > 1 and dtype.is_floating_point: + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(S, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + yield SampleInput( + make_arg(uneven_shape(D, rank)), + uneven_shape(L, rank, False), + scale_factor=None, + mode=mode, + align_corners=align_corners, + ) + for recompute_scale_factor in [False, True]: + for scale_factor in [1.7, 0.6]: + yield SampleInput( + make_arg(shape(D, rank)), + size=None, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + ) + +def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs) + + if mode in ('bilinear', 'bicubic'): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide few samples for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + for aa in [True, False]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + scale_factor=None, + mode=mode, + align_corners=False, + antialias=aa, + ) + +def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'bilinear': [2], + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return torch.Size([N, C] + ([size] * rank)) + return torch.Size([size] * rank) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for rank in ranks_for_mode[mode]: + yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False)) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7) + yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6) + +def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) + + if mode == 'bilinear': + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype + high=256 if dtype == torch.uint8 else None, + ) + # provide a single sample for more typical image processing usage + for memory_format in [torch.contiguous_format, torch.channels_last]: + yield SampleInput( + make_arg((2, 3, 345, 456), memory_format=memory_format), + (270, 270), + ) + +def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs): + N = 6 + C = 3 + H = 10 + W = 20 + S = 3 + L = 5 + + input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None) + yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9]) + yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0]) + + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9) + yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9) + +def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs): + N = 5 + for _ in range(1, N): + for approximate in ['none', 'tanh']: + yield SampleInput( + make_tensor((N * 2, N * 2), device=device, dtype=dtype, + requires_grad=requires_grad, low=-3, high=3), + approximate=approximate) + + +def error_inputs_gelu(op, device, **kwargs): + # Tests that gelu errors out when passed an approximation we don't know. + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}), + error_regex="approximate argument must be either") + + +def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs): + args_for_reduction_with_dim = ( + ((S, S, S), (1,),), + ((S, S, S), (1, True, ),), + ((), (0,),), + ((), (0, True,),), + ) + return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *args)) + for input_tensor, args in args_for_reduction_with_dim) + +def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + +def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs): + yield from _generate_reduction_inputs(device, dtype, requires_grad) + # NaN only exists for floating point numbers + if dtype.is_complex or dtype.is_floating_point: + yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) + yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) + +def sample_inputs_nan_reduction(supports_multiple_dims): + # Generates sample inputs for reduction ops that contain the input tensor + # and dim and keepdim kwargs. If a reduction op needs to test additional + # args/kwargs then create a separate sample_inputs function + def fn(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_nan_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + yield SampleInput(t.clone().requires_grad_(requires_grad)) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): + yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs) + + return fn + +def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs): + test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad)) + test_interpolations = ['linear', 'midpoint'] + + for quantiles in test_quantiles: + for t in _generate_reduction_inputs(device, dtype, requires_grad): + # Add case without dim and keepdim kwargs + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles) + for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False): + # Interpolation kwarg for now is only supported when providing both dim and keepdim + kwargs.setdefault('dim', 0) + kwargs.setdefault('keepdim', False) + for interpolation in test_interpolations: + kwargs['interpolation'] = interpolation + input = t.clone().requires_grad_(requires_grad) + yield SampleInput(input, quantiles, **kwargs) + +def sample_inputs_reduction_count_nonzero(*args, **kwargs): + """Sample inputs for count_nonzero""" + # count_nonzero does not support keepdim yet + for sample in sample_inputs_reduction(*args, **kwargs): + sample.kwargs.pop('keepdim', None) + yield sample + +def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs): + N = 10 + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + return (SampleInput(make_arg((N, N))) for _ in range(1, N)) + +def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((1, 3, 9, 9), 3), + ((1, 3, 9, 9), (4, 4)), + ((1, 3, 9, 9), (6, 6)), + ((2, 3, 9, 9), (3, 3)), + ((1, 1, 4, 4), (2, 2)), + ((1, 2, 6, 6), (4, 4))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5), + return_indices=return_indices, + ) + + yield SampleInput( + make_arg((1, 1, 16, 16)), + (1, 1), + output_ratio=(0.5, 0.5), + return_indices=True, + _random_samples=make_tensor((1, 1, 2), device=device, dtype=dtype, requires_grad=False), + ) + +def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size + cases = (((2, 3, 5, 5, 5), (2, 2, 2)), + ((1, 2, 6, 5, 4), 2), + ((1, 2, 5, 6, 5), (2, 3, 2)), + ((1, 2, 6, 6, 6), (2, 3, 2)), + ((1, 1, 7, 6, 7), (2, 3, 4)), + ((1, 1, 4, 5, 4), (2, 2, 1)), + ((1, 1, 8, 7, 6), (4, 3, 2)), + ((0, 1, 4, 5, 4), (2, 2, 1))) + + for input_shape, kernel_size in cases: + for return_indices in [False, True]: + # test case passing a single output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=2, + return_indices=return_indices, + ) + + # test case passing a tuple output size + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_size=(2, 3, 2), + return_indices=return_indices, + ) + + # test case passing an output ratio + yield SampleInput( + make_arg(input_shape), + kernel_size, + output_ratio=(0.5, 0.5, 0.5), + return_indices=return_indices, + ) + +def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2), + ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2), + ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2), + ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2), + ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2), + ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None)) + + for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases: + yield SampleInput(make_arg(input_shape), + args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)) + # Case with just input_shape and kernel_size + yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3))) + +def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, kwargs + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 9), (3,), {}), + ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)), + ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)), + ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)), + ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)), + ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override + cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [ + ((2, 3, 3, 4, 4), (2, 2, 2), {}), + ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True, + count_include_pad=False, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True, + count_include_pad=True, divisor_override=2)), + ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)), + ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False, + count_include_pad=False, divisor_override=2)), + ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=-2)), + ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True, + count_include_pad=True, divisor_override=None)), + ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None)), + ] + + for input_shape, kernel_size, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) + +def error_inputs_avg_pool1d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + +def error_inputs_avg_pool2d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 2-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + +def error_inputs_avg_pool3d(op_info, device, **kwargs): + # error inputs when pad is negative + x = torch.rand([0, 1, 49, 50], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}), + error_regex='pad must be non-negative') + + # error inputs when pad > kernel_size / 2 + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + # 3-dimensional kernel + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}), + error_regex='pad should be at most half of effective kernel size') + + # error inputs for zero divisor + x = torch.zeros(3, 3, 3, 3) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}), + error_regex='divisor must be not zero') + + # error inputs for invalid input dimension + x = torch.rand([0, 1, 49], dtype=torch.float32) + yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}), + error_regex='non-empty 4D or 5D') + + +def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # test_multiple_devices_to_cuda would fail if we use a different device than given + devices = [device] + if torch.device(device).type == 'cpu': + devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices + memory_formats = [torch.preserve_format, torch.channels_last] + + # TODO: can't switch `to.device` overload to use positional arguments + # https://github.com/pytorch/pytorch/issues/84265 + # to.device overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs) + + # to.dtype overload + for nb, cp, mem_f in product([True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs) + + # to.other overload + for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): + kwargs = { + "memory_format": mem_f, + } + other = make_arg((S, S, S, S), dtype=torch.float64, device=device) + yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs) + + +def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs): + def get_tensor_input(size): + return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(get_tensor_input((S, M, S)), 3) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True) + yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True) + + yield SampleInput(get_tensor_input(()), 1) + yield SampleInput(get_tensor_input(()), 1, 0) + yield SampleInput(get_tensor_input(()), 1, -1) + yield SampleInput(get_tensor_input(()), 1, 0, True) + yield SampleInput(get_tensor_input(()), 1, -1, True) + yield SampleInput(get_tensor_input(()), 1, 0, True, True) + yield SampleInput(get_tensor_input(()), 1, -1, True, True) + +def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(M)) + +def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S)) + ps = (2, 4) + + for size_x, size_y, p in product(sizes, sizes, ps): + yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p)) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs): + # target.index_add(dim, idx, source, *, alpha=1) + add = "index_add" in op_info.name + # target.index_copy(dim, idx, source) + copy = "index_copy" in op_info.name + # target.index_fill(dim, idx, value) + fill = "index_fill" in op_info.name + + # Extended reference inputs. We generate that exercise atomic adds / writing + # several times to one location + if reference: + make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad) + make_idx = partial(torch.zeros, device=device, dtype=torch.int64) + else: + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # idx They need to be different for copy and add to be deterministic + if copy or add: + make_idx = partial(torch.randperm, device=device, dtype=torch.int64) + else: + def make_idx(n): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n) + + shapes = [(), (1,), (S, S)] + # extra parameter for add + if add: + if dtype == torch.bool: + alphas = (True, False) + else: + alphas = (-1, 0, 2) + else: + alphas = (None,) + + if fill: + # A weird number to catch errors. + # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`. + values = (make_arg((1,)).item(), make_arg(())) + else: + values = (None,) + + for shape, alpha, value in product(shapes, alphas, values): + t = make_arg(shape) + args = [] + + # dim. We handle the scalar case + dim = -1 if t.ndim == 2 else 0 + args.append(dim) + + idx = make_idx(t.shape[dim] if t.ndim != 0 else 1) + args.append(idx) + + # source + if copy or add: + args.append(make_arg(shape)) + elif fill: + args.append(value) + + args = tuple(args) + kwargs = {} if alpha is None else {"alpha": alpha} + + yield SampleInput(t, args=args, kwargs=kwargs) + +def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m): + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m) + + shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))] + include_selfs = (True, False) + reduce = op_info.variant_test_name + assert reduce in ('prod', 'mean', 'amin', 'amax') + + for shape, include_self in product(shapes, include_selfs): + self_shape, src_shape = shape + # dim. We handle the scalar case + dim = 1 if len(self_shape) >= 2 else 0 + idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1, + self_shape[dim] if len(self_shape) != 0 else 1) + args = (dim, idx, make_arg(src_shape), reduce) + yield SampleInput(make_arg(self_shape), + args=args, + kwargs={'include_self' : include_self}) + + # Sample inputs to test edge cases for backward + if requires_grad and reduce == 'prod': + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1]) + # (c) no zeros reduced (self[2, 1], self[2, 2]) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(0, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, M), + ((S, S), M, S), + ((S, S, S), S, M), + ] + + fill_value = make_tensor([], dtype=dtype, device="cpu").item() + + for c in cases: + self_shape, high, idx_size = c + dim = len(self_shape) + indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, fill_value) + +def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_idx(n, m, dim, d): + view_shape = [1] * dim + view_shape[d] = n + return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) + + cases = [ + ((S, S), S, (M, M)), + ((S, S), M, (S, S + 1)), + ((S, S, S), S, (M, M - 1, M + 1)), + ] + + for c in cases: + self_shape, high, idx_sizes = c + dim = len(self_shape) + indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] + masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + values = make_arg(idx_sizes) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] + mask = functools.reduce(torch.logical_and, masks) + yield SampleInput(make_arg(self_shape), mask, indices, values) + + +def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): + args = ( + ((S, S, S), (),), + ((S, S, S), (1, ),), + ((S, S, S), (1, True, ),), + ((), (),), + ((), (0,),), + ((), (0, True,),), + # Non-fused mode kernel on CUDA + ((3000,), ()), + ) + make_arg = partial(make_tensor, dtype=dtype, device=device, + requires_grad=requires_grad, low=None, high=None) + return (SampleInput(make_arg(input_tensor), *args) + for input_tensor, args in args) + +# Missing to test the nondeterminism of the operation +# https://github.com/pytorch/pytorch/issues/53352 +def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs + idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S] + idx_list = [idx, -idx - 1] + for idx, acc in product(idx_list, (True, False)): + yield SampleInput(input=make_arg((S, S)), + args=(idx.clone(), + make_arg((S,)), + acc)) + + # Scalar cases + scalar_sizes = [(), (1,)] + tgt_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + src_gen = (make_arg(size) for size in scalar_sizes) + for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + + # Empty cases + tgt_sizes = [(0,), (), (1,), (3, 2)] + tgt_gen = (make_arg(size) for size in tgt_sizes) + idx = make_idx((0,), high=1) + src = make_arg((0,)) + for tgt, acc in product(tgt_gen, (True, False)): + yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), + args=(idx.clone(), + src.clone().requires_grad_(requires_grad), + acc)) + +def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) + + S = 3 + + # Generic inputs: take S elements out of S * S + index = make_idx((S,), high=(S * S)) + for idx in (index, -index - 1): + yield SampleInput(input=make_arg((S, S)), args=(idx,)) + + # Scalar cases + scalar_sizes = [(), (1,)] + src_gen = (make_arg(size) for size in scalar_sizes) + idx_gen = (make_idx(size, high=1) for size in scalar_sizes) + for src, idx in product(src_gen, idx_gen): + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + + # Empty cases + src_sizes = [(0,), (), (1,), (3, 2)] + src_gen = (make_arg(size) for size in src_sizes) + + idx = make_idx((0,), high=1) + for src in src_gen: + yield SampleInput(input=src.clone().requires_grad_(requires_grad), + args=(idx.clone(),)) + +def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0]) + yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0]) + +def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # shape, source, destination + args = ( + # empty inputs + ((), (), ()), + # int inputs, negative + ((3, 5, 7, 2), -2, 1), + # swap bounds + ((3, 5, 7, 2), (-1, 0), (0, -1)), + # non-sequential, negative + ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)), + # idempotence, negative + ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)), + # reverse, sequential, positive + ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)), + # reverse, non-sequential + ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)), + # reverse, sequential, negative + ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)), + ) + + for shape, source, destination in args: + yield SampleInput(make_arg(shape), args=(source, destination)) + +def error_movedim_moveaxis(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # source length < destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0, -1\] dims\)"), + ) + + # source length > destination length + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))), + error_regex=(r"movedim: Invalid source or destination dims: source " + r"\(\[3, -3, 4\] dims\) should contain the same number of " + r"dims as destination \(\[1, 0\] dims\)"), + ) + + # repeated source dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))), + error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)", + ) + + # repeated destination dim, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)", + ) + + # repeated dim (both), with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))), + error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)", + ) + + # out of bounds source inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination inputs, with negative indices + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds source input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + + # out of bounds destination input, int + yield ErrorInput( + SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)), + error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", + error_type=IndexError, + ) + +def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),) + shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1)) + + if requires_grad: + # Tests for variant_consistency_jit, grad, gradgrad + # are slower. Use smaller bags of `rep_dims` and `shapes` + # in this case. + rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1)) # type: ignore[assignment] + shapes = ((), (0,), (2,), (3, 2)) # type: ignore[assignment] + + is_repeat_op = op_info.name in ['repeat', '_refs.repeat'] + for rep_dim, shape in product(rep_dims, shapes): + # `torch.repeat` errors for `len(rep_dims) < t.dim()`, + # so we filter such combinations. + if is_repeat_op and len(rep_dim) < len(shape): + continue + yield SampleInput(make_arg(shape), rep_dim) + + +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + shapes_and_args = ( + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) + + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. + # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + +def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs): + + y_shape_x_shape_and_kwargs = [ + ((2, 3), (2, 3), {}), + ((2, 3), (2, 3), {'dim': 1}), + ((6,), (6,), {}), + ((6,), None, {}), + # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad + # See Issue #{61619} + # ((6,0), (6,0), {}), + ((2, 3), (1, 3), {}), + ((3, 3), (3, 3), {}), + ((3, 3), (3, 3), {'dim': -2}), + ((5,), None, {'dx': 2.0}), + ((2, 2), None, {'dx': 3.0}) + ] + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=None, high=None) + for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: + y_tensor = make_arg(y_shape) + if x_shape is not None: + x_tensor = make_arg(x_shape) + yield SampleInput(y_tensor, x_tensor, **kwarg) + else: + yield SampleInput(y_tensor, **kwarg) + +def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_axes = [ + ((3, 4, 5), 0), + ((3, 4, 5), 1), + ((3, 4, 5), 3), + ((3, 4, 5), -1), + ((3, 4, 5), -3), + ((), 0), + ((), -1), + ((1,), 0), + ((1,), -1), + ] + + for shape, axis in shapes_and_axes: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, axis) + + +def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((0, 1, 5, 5), (2, 3, 5, 5)) + kernel_sizes = (2, (2, 2), (2, 3)) + dilations = (1, 2, (1, 2)) + paddings = (0, 1, (1, 2)) + strides = (1, 2, (1, 2)) + + cases = product(shapes, kernel_sizes, dilations, paddings, strides) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + for shape, kernel_size, dilation, padding, stride in cases: + tensor = make_arg(shape) + yield SampleInput(tensor, kernel_size, dilation, padding, stride) + + # With default args + yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3)) + + +def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((S, 1, S, 1), ()), + ((1, 1, 1, 1), ()), + ((1, 1, 1, 1), (0,)), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (2,)), + ((S, 1, S, 1), (-2,)), + ((), (0, )), + ) + + for shape, args in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, args=args) + + +def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs): + shapes_and_args = ( + ((1, 1, 1, 1), ()), + ((S, 1, S, 1), (1,)), + ((S, 1, S, 1), (-1,)), + ((S, 1, S, 1), (1, 3)), + ((S, 1, S, 1), (1, 2,)), + ((), (0,)), + ) + + for shape, dims in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + + yield SampleInput(tensor, dims) + + +def _squeeze_ref(x, axis=None): + # NumPy doesn't allow squeezing scalars + if x.ndim == 0: + return x + + if isinstance(axis, Sequence): + # Numpy doesn't allow specifying non-singular dimensions + axis = tuple(a for a in axis if x.shape[a] == 1) + + if isinstance(axis, int) and x.shape[axis] != 1: + return x + + return np.squeeze(x, axis) + +def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs): + assert mode in ('constant', 'reflect', 'replicate', 'circular') + if mode in ['reflect', 'replicate']: + cases: tuple = ( # ignore + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + elif mode == 'constant': + cases = ( + ((1, 3), (1, 2)), + ((1, 3), (0, 1)), + ((1, 3), (0, 2, 0, 1)), + ((5, 3), (-1, -2, 1, 1)), + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((0, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((1, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3), (1, 1, 1, 1, 1, 1)), + ((0, 3, 3, 3), (1, 2)), + ((0, 3, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((3, 3, 5, 5), (1, 2)), + ((3, 3, 5, 5), (0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), + ((1, 3, 3, 3, 3), (1, 2)), + ((1, 3, 3, 3, 3), (0, 1)), + ((1, 3, 3, 3, 3), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + else: # mode == 'circular' + if dtype == torch.bool: + # test_dtypes fails on ASAN with for the case ab + # runtime error: load of value 190, which is not a valid value for type 'bool' + # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562 + # Reference Issue: https://github.com/pytorch/pytorch/issues/63034 + cases = ( + ((2, 3, 3), (1, 2)), + ((1, 3, 3), (1, 2)), + ) + else: + cases = ( + ((0, 3, 3), (1, 2)), + ((0, 3, 3), (0, 1)), + ((1, 3, 3), (1, 2)), + ((1, 3, 3), (0, 1)), + ((0, 3, 3, 3), (0, 2, 0, 1)), + ((3, 3, 5, 5), (0, 2, 0, 1)), + ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), + ((1, 3, 4, 4), (-1, 1, -2, 1)), + ) + + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if mode == 'constant': + # Default args + yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),)) + + if mode in ['reflect', 'replicate', 'circular']: + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode)) + else: # mode == 'constant' + for pad_value in (1., 2.): + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, mode, pad_value)) + +def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple = ( + ((5, 3, 4, 4), (-4, 5, 0, 0)), + ((6, 2, 4, 4), (0, 0, 2, -4)), + ((5, 6, 4, 4), (5, -4, -4, 3)), + ((4, 2, 5, 5), (-2, -1, 4, 6)), + ((2, 6, 5, 5), (8, -1, -1, -3)), + ((8, 1, 5, 5), (-2, -1, -1, -3)), + ) + make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for shape, pad in cases: + yield SampleInput(make_inp(shape), args=(pad, 'replicate')) + +def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs): + # Inherit sample inputs from nn.pad, but transform them to fit + # constant_pad_nd's interface + nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args, + mode='constant', **kwargs) + + # NOTE: primTorch is more strict about the type of the fill value argument + # So we must cast it to the correct dtype + from torch._prims_common import dtype_to_type + scalar_type = dtype_to_type(dtype) + + def drop_mode_argument(input, pad, mode=None, value=None): + if value is None: + return SampleInput(input, args=(pad,)) + else: + return SampleInput(input, args=(pad, scalar_type(value))) + + for sample in nn_samples: + yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs) + +def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(()), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2) + yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1) + yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1) + yield SampleInput(make_input((4, 1)), repeats=torch.arange(4, device=device), dim=0, output_size=6) + + +def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): + def mt(shape, **kwargs): + return make_tensor(shape, device=device, dtype=dtype, + requires_grad=requires_grad, **kwargs) + + yield SampleInput(mt(100), n_fft=10, return_complex=True) + yield SampleInput(mt(100), n_fft=10, return_complex=False) + if dtype.is_complex: + yield SampleInput(mt(100), n_fft=10) + + for center in [False, True]: + yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True) + yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4, + center=center, return_complex=True) + + window = mt(16, low=.5, high=2.0) + yield SampleInput( + mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + yield SampleInput( + mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) + if not dtype.is_complex: + yield SampleInput( + mt((10, 100)), n_fft=16, window=window, onesided=False, + return_complex=True) + + +def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def mt(shape, **kwargs): + real_shape = shape if dtype.is_complex else shape + (2,) + return make_arg(real_shape, **kwargs) + + yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10)) + yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False)) + yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True)) + + for center in [False, True]: + yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center)) + yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center)) + + window = make_arg(10, low=.5, high=2.0) + yield SampleInput(mt((10, 10, 6)), kwargs=dict( + n_fft=10, window=window, center=center, return_complex=dtype.is_complex)) + yield SampleInput(mt((10, 10, 10)), kwargs=dict( + n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True)) + + real_window = window if not dtype.is_complex else window.real + yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center)) + +def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs): + # create a helper function wrapping `make_tensor` + make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1) + + batches = [(), (0, ), (2, ), (2, 1)] + ns = [5, 2, 0] + tf = [True, False] + for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf): + input = make_input((*batch, m, n)) + reflectors, tau = torch.geqrf(input) + reflectors.requires_grad_(requires_grad) + tau.requires_grad_(requires_grad) + other_matrix_shape = (m, n) if left else (n, m) + other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad) + yield SampleInput(reflectors, tau, other, left=left, transpose=transpose) + + +def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs): + cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse( + op_info, device, dtype, requires_grad=False + ) + + for sample in cholesky_inverse_samples: + psd_matrix = sample.input + sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + sample.args = (psd_matrix.requires_grad_(requires_grad),) + yield sample + + +def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_fullrank_matrices_with_distinct_singular_values, + dtype=dtype, device=device, requires_grad=requires_grad) + + # not needed once OpInfo tests support Iterables + batch_shapes = ((), (3,), (3, 3)) + for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)): + shape = batch_shape + (S + size_delta, S) + input = make_arg(*shape) + yield SampleInput(input, args=(True, get_infos)) + + +def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs): + def out_fn(output): + return output[1], output[2] + + for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs): + lu_data, pivots = torch.linalg.lu_factor(lu_sample.input) + lu_data.requires_grad_(requires_grad) + yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn) + + +def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2))) + + for arg in args: + yield SampleInput(make_arg((0, 0, 0)), args=arg) + yield SampleInput(make_arg((S, S, S)), args=arg) + + # Scalar tensor + yield SampleInput(make_arg(()), args=(10, )) + +def error_inputs_roll(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "`shifts` required" + s1 = SampleInput(make_arg((S,)), ()) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = ("shifts and dimensions must align") + s2 = SampleInput(make_arg((S, S)), (2, 1), 0) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = ("out of range") + s3 = SampleInput(make_arg((S, )), 0, 2) + yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError) + + err_msg4 = ("Dimension specified as 0") + s4 = SampleInput(make_arg(()), 0, 0) + yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError) + +def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)]) + + yield SampleInput(make_arg((S, S, S))) + for arg in args: + yield SampleInput(make_arg((S, S, S)), args=arg) + + +def error_inputs_rot90(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + err_msg1 = "expected total rotation dims" + s1 = SampleInput(make_arg((S, S)), dims=(0,)) + yield ErrorInput(s1, error_regex=err_msg1) + + err_msg2 = "expected total dims >= 2" + s2 = SampleInput(make_arg((S,))) + yield ErrorInput(s2, error_regex=err_msg2) + + err_msg3 = "expected rotation dims to be different" + s3 = SampleInput(make_arg((S, S)), dims=(1, 1)) + yield ErrorInput(s3, error_regex=err_msg3) + + +def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs): + tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype, + requires_grad=requires_grad) + tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype, + requires_grad=requires_grad) + + yield SampleInput(tensor_nd()) + yield SampleInput(tensor_nd(), dim=1) + yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True) + yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False) + + yield SampleInput(tensor_nd(), dim=(1,), correction=1.3) + yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) + yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) + yield SampleInput(tensor_nd(), dim=None, correction=None) + yield SampleInput(tensor_nd(), dim=None, correction=-1) + yield SampleInput(tensor_nd(), dim=None, correction=-5) + yield SampleInput(tensor_nd(), correction=0.5, keepdim=True) + yield SampleInput(tensor_nd(), correction=0, keepdim=True) + yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) + + +def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad) + + # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + yield SampleInput(make_arg((S, S)), True) + yield SampleInput(make_arg((S,)), False) + + +def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs): + shapes = [(2,), (1, 2), (3, 2), (2, 3)] + for shape in shapes: + yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + + +def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs): + return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad)) + +def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_elementwise_binary(op_info, device, dtype, requires_grad, **kwargs) + if dtype.is_floating_point: + yield SampleInput(make_tensor(5, dtype=dtype, device=device, requires_grad=requires_grad), -3.14) + + +def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs): + for t in _generate_correlation_inputs(device, dtype, requires_grad): + yield SampleInput(t) + num_observations = t.numel() if t.ndimension() < 2 else t.size(1) + fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10) + aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad) + for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]): + yield SampleInput(t.clone().requires_grad_(requires_grad), + correction=correction, fweights=fw, aweights=aw) + + +def error_inputs_cov(op_info, device, **kwargs): + a = torch.rand(S, device=device) + yield ErrorInput( + SampleInput(torch.rand(S, S, S, device=device)), + error_regex="expected input to have two or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, S, device=device)), + error_regex="expected fweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(S, S, device=device)), + error_regex="expected aweights to have one or fewer dimensions") + yield ErrorInput( + SampleInput(a, fweights=torch.rand(S, device=device)), + error_regex="expected fweights to have integral dtype") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([1, 1], device=device)), + error_regex="expected aweights to have floating point dtype") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([1], device=device)), + error_regex="expected fweights to have the same numel") + yield ErrorInput( + SampleInput(a, aweights=torch.rand(1, device=device)), + error_regex="expected aweights to have the same numel") + yield ErrorInput( + SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4 , -5], device=device)), + error_regex="fweights cannot be negative") + yield ErrorInput( + SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)), + error_regex="aweights cannot be negative") + + +def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = [((1, 2, 3, 4), (0, 2, 3, 1)), + ((1, 2, 3, 4), (0, -2, -1, 1)), + ((), ()), + ((1, 2, 3, 4), (2, 1, 3, 0))] + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=(args,)) + +def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = ( + ((), ()), + ((1,), (0,)), + ((2, 2), (1, 0)), + ((2, 2), (0, 1)), + ((2, 0, 1), (0, 2, 1)), + ((3, 4, 2), (2, 1, 0)), + ((3, 4, 2), (1, 0, 2)), + ((3, 4, 2), (0, 1, 2)), + ) + + # Adds tricky permutations and permutations with noncontiguity + for shape, permutation in cases: + for p in itertools.permutations(permutation): + a = make_arg(shape).permute(p) + yield SampleInput(a, args=(permutation,)) + + a = make_arg(shape, noncontiguous=True).permute(p) + yield SampleInput(a, args=(permutation,)) + +def error_inputs_softshrink(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}), + error_regex=r"lambda must be in range \[0,.*input dtype.*found -0\.5") + +def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for lbda in (0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of lambd beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + # Note that unlike softshrink, lambd is allowed to be negative for hardshrink + for lbda in (-0.5, 0., 0.5): + yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + + +def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # The additional sample is to check additional values of min_val and max_val beyond the default + # value (what is already checked by sample_inputs_elementwise_unary) + for max_val, min_val in ((0.5, -0.5), (0., 0.)): + yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val}) + + yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) + +def error_inputs_hardtanh(op_info, device, **kwargs): + # Tests that hardtanh errors out when passed min_val > max_val. + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}), + error_type=ValueError, error_regex="min_val cannot be greater than max_val") + +def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs): + def c(t): + return t.clone().requires_grad_(requires_grad) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + x = make_arg((3,)) + y = make_arg((4,)) + A = make_arg((2, 3,)) + B = make_arg((1, 3,)) + C = make_arg((1, 2, 3,)) + D = make_arg((1, 3, 4,)) + E = make_arg((4, 4,)) + H = make_arg((3, 3,)) + I = make_arg((1, 3, 1,)) + + # Vector operations + yield SampleInput([c(x)], 'i->') # sum + yield SampleInput([c(x), c(y)], 'i,j->ij') # outer + + # Matrix operations + yield SampleInput([c(A)], "ij->i") # col sum + yield SampleInput([c(A), c(B)], "ij,kj->ik") # matmul + yield SampleInput([c(A), c(E)], "ij,Ab->ijAb") # matrix outer product + + # Tensor operations + yield SampleInput([c(C), c(D)], "aij,ajk->aik") # batch matmul + yield SampleInput([c(D), c(E)], "aij,jk->aik") # tensor matrix contraction + yield SampleInput([c(C), c(B)], "ijk,ik->j") # non contiguous + + # Test diagonals + yield SampleInput([c(I)], 'iji->j') # non-contiguous trace + + # Test ellipsis + yield SampleInput([c(H)], "i...->...") + yield SampleInput([c(C), c(x)], '...ik, ...j -> ij') + + +def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + sizes = ((S, M, S), (S, 0, M)) + all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ()) + + for size, dims in product(sizes, all_dims): + yield SampleInput(make_arg(size), kwargs={"dims": dims}) + +def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs): + shapes = [ + (S, M, S), + (S, 0, M), + ] + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes) + +def error_inputs_fliplr(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)), + error_regex="Input must be >= 2-d.") + +def error_inputs_flipud(op, device, **kwargs): + yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)), + error_regex="Input must be >= 1-d.") + +def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) + make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False) + shape = (S, M, S) + + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:]))) + yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),)) + yield SampleInput(make_arg(shape), args=(None, make_arg(shape))) + yield SampleInput(make_arg(shape), args=(make_arg(shape), None)) + # test type promotion + yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None)) + yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape))) + +def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs): + yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad) + supported_dtypes = op.supported_dtypes(device) + + # broadcasting and oncontiguous cases + cases = ( + ((4, 4), (4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 4)), + ((4, 4), (1, 4, 4), (4, 1, 4)), + ((4, 4, 1), (1, 4, 4), (4, 4)), + ((4, 1), (1, 4, 4), (1, 4)), + ((4, 4), (), (4, 4)), + ((4, 4), (), ()), + ((), (4, 4), (1, 4, 4)), + ) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c))) + yield SampleInput(make_arg(a, noncontiguous=True), + args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1))) + + # scalar cases + if supports_scalars: + cases = [ + ((), 1, 2,), + ((), 1., 2), + ((4, 4), 1., 2,), + ((3, 4), make_scalar_tensor(), make_scalar_tensor()), + ] + + if torch.complex64 in supported_dtypes: + cases.extend([ + ((3, 1, 4), complex(1, 2), 3.), + ]) + + for a, b, c in cases: + yield SampleInput(make_arg(a), args=(b, c)) + + # type promotion cases + # int x float + if torch.float in supported_dtypes and torch.long in supported_dtypes: + a = make_arg((), dtype=torch.long) + b = make_arg((1, 4), dtype=torch.float) + c = make_arg((3, 4)) + + cases = ( + (a, b, c), + (c, a, b), + ) + + for a, b, c in cases: + yield SampleInput(a, args=(b, c)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan')) + + a = make_arg((12,)) + a[4] = nan + a[7] = nan + b = make_arg((12,)) + b[1] = nan + b[7] = nan + c = make_arg((12,)) + c[9] = nan + + yield SampleInput(a, args=(b, c)) + + +def _clamp_min_numpy(a, min=None): + return np.maximum(a, min) + + +def _clamp_max_numpy(a, max=None): + return np.minimum(a, max) + + +def _clamp_numpy(a, min=None, max=None): + if min is None: + return np.minimum(a, max) + if max is None: + return np.maximum(a, min) + + return np.minimum(max, np.maximum(a, min)) + + +def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_zeros(dim_select): + assert len(dim_select) == 2 + result = make_arg(3 * (S,)) + result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() + result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() + result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() + return result + + for dim in range(3): + yield SampleInput(make_arg((S, S, S)), args=(dim,)) + # Scalar tensors and empty tensor + for size in [(), (1,), (0,)]: + yield SampleInput(make_arg(size), args=(0,)) + + yield SampleInput(prod_zeros([0, 1]), args=(1,)) + yield SampleInput(prod_zeros([0, 2]), args=(1,)) + yield SampleInput(prod_zeros([1, 2]), args=(1,)) + + # test dtype kwarg + yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype}) + +def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad)) + +def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + sizes = ((S, S), ()) + return (SampleInput(make_arg(size)) for size in sizes) + +def error_inputs_complex(op_info, device, is_ref=False, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + other_dtype = torch.float16 if device.startswith("mps") else torch.float64 + other_dtype_name = "Half" if device.startswith("mps") else "Double" + + if is_ref: + error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" + error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" + error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" + else: + error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" + error_dtype = f"Expected object of scalar type Float but got scalar type {other_dtype_name} for second argument" + error_out = f"Expected object of scalar type Complex{other_dtype_name} but got scalar type ComplexFloat for argument 'out'" + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), + error_type=RuntimeError, error_regex=error_float) + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=other_dtype)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=other_dtype), make_arg(M, S, dtype=other_dtype), + out=make_arg(M, S, dtype=torch.complex64)), + error_type=RuntimeError, error_regex=error_out) + +def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = (S, S) + yield SampleInput(make_arg(shape), make_arg(shape)) + +def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): + def make_arg(shape): + # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck + return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + + def prod_single_zero(): + result = make_arg(2 * (S,)) + result[0, 1] = 0 + return result + + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + # only Tensor, ignore other inputs + yield SampleInput(sample.input.clone().requires_grad_(requires_grad)) + yield sample + + # Generates samples with keepdim = True + for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): + sample.kwargs['keepdim'] = True + yield sample + + yield SampleInput(prod_single_zero()) + yield SampleInput(make_arg((3, 3, 3)), args=(1,)) + yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True}) + + yield SampleInput(make_arg((3, 0)), args=(1,)) + yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True}) + yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + + # test zero scalar tensor + zero = make_arg(()) + zero.zero_() + yield SampleInput(zero.clone().requires_grad_(requires_grad)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,)) + yield SampleInput(zero.clone().requires_grad_(requires_grad), + args=(0,), + kwargs={'keepdim': True}) + +def error_inputs_neg(op_info, device, **kwargs): + si = SampleInput(torch.tensor((False, True), device=device)) + msg = ("Negation, the `\\-` operator, on a bool tensor is not supported." + " If you are trying to invert a mask, use the `\\~` or" + " `logical_not\\(\\)` operator instead.") + yield ErrorInput(si, error_regex=msg) + +def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + yield SampleInput(make_arg(M)) + + tensors = ( + make_arg((M, M)), + make_arg((3, 5)), + make_arg((5, 3)), + ) + + args = ((), (2,), (-2,), (1,), (2,)) + + for tensor, arg in product(tensors, args): + yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg) + +def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_diagonal_diag_embed( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes1d = ((0,), (1,)) + shapes2d = ((L, M),) + shapes3d = ((L, M, S),) + + kwargs1d = {} + + kwargs2d = ( + # dim1 > dim2 is allowed + dict(dim1=1, dim2=0), + # negative dims are allowed + dict(dim1=-2, dim2=-1), + # one dim negative and the other nonnegative is allowed + dict(dim1=-1, dim2=0), + # out of bounds offset should return an empty tensor in diagonal and + # offset the diagonal in diag_embed + dict(offset=100), + ) + + kwargs3d = kwargs2d + ( + # make sure we can use non-sequential dims + dict(offset=-1, dim1=0, dim2=2), + ) + + samples1d = product(shapes1d, kwargs1d) + samples2d = product(shapes2d, kwargs2d) + samples3d = product(shapes3d, kwargs3d) + + for shape, kwargs in chain(samples1d, samples2d, samples3d): + if 'diagonal' in op_info.name: + # these are error inputs for diagonal + if shape in ((0,), (1,)): + continue + yield SampleInput(input=make_arg(shape), kwargs=kwargs) + + +def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # Shapes for 2D Tensors + shapes_2d = ((M, M), (3, 5), (5, 3)) + + # Shapes for 3D Tensors + shapes_3d = ((M, M, M),) + + args_2d = ((), (2,), (-2,), (1,)) + args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1)) + + for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)): + input_ = make_arg(input_shape) + # We can programmatically figure out the right shape for src: + # It should be the same size as input.diagonal(other_args...) + if not isinstance(arg, tuple): + arg_tuple = (arg,) + else: + arg_tuple = arg + src_shape = input_.diagonal(*arg_tuple).size() + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *arg_tuple)) + + +def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) + +def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs): + batch_size, num_classes = shape = (2, 3) + reductions = ("mean", "sum", "none") + + input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [ + (shape, {}), + ((*shape, 1), {}), + ((*shape, 1, 2), {}), + ((*shape, 1, 2, 3), {}), + *[(shape, dict(reduction=reduction)) for reduction in reductions], + *[ + ( + shape, + dict( + weight=make_tensor((num_classes,), device=device, dtype=dtype), + reduction=reduction, + ), + ) + for reduction in reductions + ], + (shape, dict(ignore_index=1)), + ] + + for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)): + input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad) + + if probabilities_target: + # ignore_index is not supported for probabilities target + if "ignore_index" in kwargs: + continue + + target = make_tensor( + input_shape, + low=0, + high=1, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + else: + target = make_tensor( + (batch_size, *input_shape[2:]), + low=0, + high=num_classes, + device=device, + dtype=torch.long, + ) + + if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]): + # make sure at least one item in target is not ignored + target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0] + + yield SampleInput(input, target, **kwargs) + + +def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): + low, high = op_info.domain + + # Note: Operator is very sensitive at points near the + # start and end of domain and leads to NaN for float16 + # if domain_eps is 1e-5. + if dtype.is_floating_point or dtype.is_complex: + domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2 + + low = low + domain_eps + high = high - domain_eps + + make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg((S, S, S)), 0.2) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg(()), 0.2) + +def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # isin has two paths based on the size of elements and test_elements. + # if elements.numel() < 10 * pow(test_elements.numel(), 0.145): + yield SampleInput(make_arg((L,)), args=(make_arg((S,)),)) + # else: + yield SampleInput(make_arg((S,)), args=(make_arg((L,)),)) + +def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S)))) + yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S)))) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))), + broadcasts_input=True) + +def error_inputs_masked_scatter(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float) + for mask_dtype in [torch.float, torch.uint8]: + yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype), + make_arg(3, 4))), + error_regex=r"masked_scatter_ only supports boolean masks") + +def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10)) + yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(()))) + yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10)) + + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, make_arg(())), + broadcasts_input=True) + yield SampleInput(make_arg((S,)), + args=(torch.randn(S, S, device=device) > 0, 10), + broadcasts_input=True) + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CUDA but `value` is a CPU scalar tensor. + yield SampleInput(make_arg((S, S)), + args=(torch.randn(S, S, device=device) > 0, + make_tensor((), device="cpu", dtype=dtype))) + +def error_inputs_masked_fill(op_info, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + # `value` is not a 0-D tensor. + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))), + error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension") + # downcasting complex value (scalar overload) + yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)), + error_regex=r"value cannot be converted to type .* without overflow") + # downcasting complex value (tensor overload) + yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device), + args=(make_arg(()) > 0, torch.tensor(1j, device=device))), + error_regex=r"value cannot be converted to type .* without overflow") + + if torch.device(device).type == 'cuda': + # `self` and `mask` on CPU but `value` is a CUDA scalar tensor. + yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'), + args=(torch.randn(S, S, device='cpu') > 0, + torch.randn((), device='cuda'))), + error_regex=r"to be on same device") + + +def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) + + yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0) + + yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0) + yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0) + + yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool)) + + yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0) + +def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S))) + yield SampleInput(make_arg((S, S, S))) + +def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, + high=None, requires_grad=requires_grad) + test_cases = (((L,), (L,)), + ((S, M), (M,)), + ((M,), (M, S)), + ((S, M), (M, S)), + ((S, 0), (0, M)), + ((S, S, M), (M,)), + ((S, S, M), (M, S)), + ((S, S, 0), (0, S)), + ((M,), (S, M, S)), + ((S, M), (S, M, S)), + ((0, 0), (S, 0, 0)), + ((S, S, M, M), (S, S, M, S)), + ((S, S, M, M), (M,)), + ((M,), (S, S, M, S)), + ((S, S, S), (1, S, S)) + ) + for lhs_shape, rhs_shape in test_cases: + lhs = make_arg(lhs_shape) + rhs = make_arg(rhs_shape) + if not is_rmatmul: + yield SampleInput(lhs, rhs) + else: + yield SampleInput(rhs, lhs) + + +def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype, + requires_grad: bool, + *, variant: str, **kwargs) -> list[SampleInput]: + if variant == 'variadic': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return tensors + elif variant == 'list': + def make_inputs( + tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor, + list[torch.Tensor]], + tuple[torch.Tensor, ...]]: + return [tensors] + else: + raise ValueError( + 'Unsupported variant, must be one of {"variadic", "list"}. ' + f'Got "{variant}".') + + SCALAR = torch.Size([]) + VECTOR = torch.Size([3]) + test_cases: list[list[torch.Size]] = [ + [SCALAR], + [VECTOR], + [VECTOR, SCALAR], + [VECTOR, SCALAR, VECTOR], + [VECTOR, SCALAR, VECTOR, SCALAR], + ] + + for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}): + args = make_inputs( + [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes]) + yield SampleInput(*args, indexing=indexing) + + +def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + tensor_shapes = ((S, S), ()) + ns = (1, 2, 3, 4, 5) + + # Since the accepted lower bound for input + # to mvlgamma depends on `p` argument, + # the following function computes the lower bound + # which we pass to `make_tensor`. + def compute_min_val(p): + return (p - 1.) / 2 + + for shape, n in product(tensor_shapes, ns): + min_val = compute_min_val(n) + if not dtype.is_floating_point: + # Round-up minimum value for integral dtypes + min_val += 1 + else: + min_val += 2 * torch.finfo(dtype).eps + yield SampleInput(make_arg(shape, low=min_val), args=(n,)) + + +# Since `mvlgamma` has multiple entries, +# there are multiple common skips for the additional +# entries. Following function is a helper to that end. +def skips_mvlgamma(skip_redundant=False): + skips = ( + # outside domain values are hard error for mvlgamma op. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'), + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ) + if skip_redundant: + # Redundant tests + skips = skips + ( # type: ignore[assignment] + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + ) + return skips + + +# To test reference numerics against multiple values of argument `p`, +# we make multiple OpInfo entries with each entry corresponding to different value of p. +# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing. +def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs): + return UnaryUfuncInfo('mvlgamma', + ref=reference_mvlgamma if TEST_SCIPY else None, + aliases=('special.multigammaln',), + variant_test_name=variant_test_name, + domain=domain, + decorators=(precisionOverride({torch.float16: 5e-2}),), + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_mvlgamma, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=skips, + sample_kwargs=sample_kwargs) + + +def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs): + def _make_tensor_helper(shape, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + yield SampleInput(_make_tensor_helper((S, S, S)), 0) + yield SampleInput(_make_tensor_helper((S, S, S)), 1) + yield SampleInput(_make_tensor_helper(()), 0) + + if supports_dtype_kwargs: + # NOTE: if `dtype` is not same as input, then inplace variants fail with + # `provided dtype must match the dtype of self tensor in cumsum` + yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype) + + +def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): + test_cases = ( + ((), (0, 1, 1)), + ((S, S, S, S), (0, 3, 1)), + ((S, S, S, S), (1, 3, 1)), + ((S, S, S, S), (2, 3, 1)), + ((S, S, S, S), (3, 3, 1)), + ((S, S, S, S), (0, 3, 2)), + ((S, S, S, S), (1, 3, 2)), + ((S, S, S, S), (2, 3, 2)), + ((S, S, S, S), (3, 3, 2)), + ((S, S, S, S), (0, 4, 1)), + ((S, S, S, S), (1, 4, 1)), + ((S, S, S, S), (2, 4, 1)), + ((S, S, S, S), (3, 4, 1)), + ((M,), (0, 3, 1)), + ((M,), (0, 3, 2)), + ((M,), (0, 3, 3)), + ((1000,), (0, 3, 11)), + ((1000,), (0, 2, 27)), + ((10, 10), (0, 1, 2)), + ((10, 10), (1, 2, 3)), + ((10, 10), (1, 2, 2)), + ((S, S, S), (2, 3, 2)), + ) + + for shape, arguments in test_cases: + yield SampleInput(make_tensor(shape, dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad), + *arguments) + +def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if list_args: + cases = ( + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) + ) + else: + cases = ( # type: ignore[assignment] + ((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs): + def apply_grad(t): + if dtype in floating_types_and(torch.float16, torch.bfloat16): + t.requires_grad_(requires_grad) + + def large_1d_unique(dtype, device): + res = torch.randperm(L * L * L, dtype=torch.int64, device=device) + res = res.to(dtype) + apply_grad(res) + return res + + # Test case for large tensor. + yield SampleInput(large_1d_unique(dtype, device)) + + yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad)) + +def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4) + # broadcast rhs with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S))) + # broadcast rhs and weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,))) + # broadcast lhs + yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # scalar broadcast_lhs + yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) + # tensor broadcast all + yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata( + broadcasts_input=True) + # no broadcast with weight tensor + yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S))) + # broadcast lhs with weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor + yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata( + broadcasts_input=True) + # broadcast lhs and weight tensor variant + yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata( + broadcasts_input=True) + + if dtype.is_complex: + # no broadcast + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j) + # broadcast rhs + yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j) + # scalar tensor + yield SampleInput(make_arg(()), make_arg(()), 0.4j) + yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j) + # broadcast rhs scalar-tensor + yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j) + yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j) + +def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs): + cases = ( + ((2, 2, 2), (2, 2, 2), (2)), + ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])), + ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])), + ) + for first_shape, second_shape, dims in cases: + yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + make_tensor(second_shape, dtype=dtype, device=device, + requires_grad=requires_grad, low=-1, high=+2), + dims=dims) + +def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) + test_cases = ( + ((S, S), (M, L)), + ) + + for input_shape, other_shape in test_cases: + input = make_arg(input_shape) + other = make_arg(other_shape) + yield SampleInput(input, other) + +def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(make_arg(S), make_arg(S)) + yield SampleInput(make_arg(), make_arg(S, S)) + +def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))), + (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))), + (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))), + (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), + (_tensor(()), (0, zero.detach().clone(), _tensor(()))), + (_tensor(()), (0, zero.detach().clone(), 2.5)), + ) + + for tensor, args in test_cases: + yield SampleInput(tensor, *args) + + if not requires_grad: + yield SampleInput(tensor.detach().clone(), *args, reduce='add') + + if dtype.is_floating_point: + yield SampleInput(tensor.detach().clone(), *args, reduce='multiply') + +def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + def _gather(shape, index_dim, max_indices): + return gather_variable(shape, index_dim, max_indices, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S))) + yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) + yield SampleInput(_tensor(()), 0, zero.detach().clone(), _tensor(())) + +def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + gather = partial(gather_variable, device=device) + + zero = torch.tensor(0, dtype=torch.long, device=device) + test_cases = ( + ((M, S), 0, gather((S, S), 1, M), (S, S)), + ((M, S), 1, gather((S, S), 0, S), (S, S)), + ((M, S), -1, gather((S, S), 0, S), (S, S)), + ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)), + ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)), + ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)), + ((), 0, zero.detach().clone(), ()), + ) + + reduce = op_info.variant_test_name + for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]): + yield SampleInput(make_arg(inp_shape), + args=(dim, index, make_arg(src_shape), reduce), + kwargs={'include_self': include_self}) + + + # Sample inputs to test edge cases for backward + # Check that gradients are propagated correctly for prod when zeros in self/src are reduced + if requires_grad and reduce == 'prod': + # This sample tests gradients for the following cases + # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0])) + # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]) + # (c) no zeros reduced (self([2, 1])) + # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py + # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad + input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad) + src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad) + idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device) + + yield SampleInput(input, + args=(1, idx, src, reduce), + kwargs={'include_self': True}) + +def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = ( + # inp_shape, dim, lengths, unsafe + ((S,), 0, [0, 1, 2, 2], False), + ((S,), 0, [0, 1, 2, 2], True), + ((S,), 0, [2, 0, 3, 0], False), + ((S, S), 0, [0, 1, 2, 2], False), + # test when lengths do not sum to dim size + ((M, S, S), 0, [1, 2, 0, 6, 0], True), + # test for higher dimensions + ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False), + ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False), + ) + + reductions = ["max", "mean", "min", "sum", "prod"] + for args, reduce, initial in product(test_cases, reductions, [1, 2]): + inp_shape, dim, lengths, unsafe = args + lengths_t = torch.tensor(lengths, dtype=torch.long, device=device) + sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial} + if mode == 'lengths': + sample_input_kwargs['lengths'] = lengths_t + elif mode == 'offsets': + zeros_shape = list(lengths_t.shape) + zeros_shape[dim] = 1 + offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim) + sample_input_kwargs['offsets'] = offsets_t + else: + raise RuntimeError(f"mode most be one of 'offsets' or 'lengths' got '{mode}'.") + yield SampleInput(_tensor(inp_shape), + args=(reduce,), + kwargs=sample_input_kwargs) + + +def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, + low=None, high=None, requires_grad=requires_grad) + yield SampleInput(make_arg((S, S, S))) + yield SampleInput(make_arg(())) + yield SampleInput(make_arg((S, S, S), noncontiguous=True)) + +def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput( + torch.tensor( + [[3, 8, 13], [0, 5, 10]], + device=device, + dtype=dtype), + (4, 5)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (4, 2**30)) + yield SampleInput( + torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), + (2**30, 4)) + yield SampleInput( + torch.tensor(2, device=device, dtype=dtype), + (2, 2)) + max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1 + yield SampleInput( + torch.tensor(max_val - 1, device=device, dtype=dtype), + (1, max_val)) + yield SampleInput( + torch.tensor([22, 41, 37], device=device, dtype=dtype), + (7, 6)) + yield SampleInput( + torch.tensor(min(1621, max_val), device=device, dtype=dtype), + (6, 7, 8, 9)) + yield SampleInput( + torch.tensor([], device=device, dtype=dtype), + (10, 3, 5)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], + device=device, + dtype=dtype), + (5, 8)) + yield SampleInput( + torch.tensor( + [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]], + device=device, + dtype=dtype), + (5, 8, 10)) + yield SampleInput( + torch.tensor(0, device=device, dtype=dtype), + ()) + + a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]]) + b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]]) + _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]]) + b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]]) + _, i1, i2 = np.intersect1d(a, b, return_indices=True) + yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) + yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) + + +def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + cases = (((M, M), ()), + ((M, M), (2,),), + ((M, S), ()), + ((M, S), (-1,)), + ((M, M), (2,),), + ((S, M, S), ()), + ((S, M, S), (2,)), + ((3, 3, S, S), ()),) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + +def error_inputs_tril_triu(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for input.ndim <= 2 + yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions") + +def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs): + # (row, col, offset) + args_list = ((0, 0), + (20, 0), + (0, 20), + (20, 21, 0), + (20, 21, 7), + (20, 21, -7), + # Large test cases below are deliberately commented out to speed up CI + # tests and to avoid OOM error. When modifying implementations of + # tril_indices and triu_indices, please enable these tests and make sure + # they pass. + # (2, 68435455, 3), + # (5000, 5000), + # (5000, 5000, 1234), + # (5000, 5000, -1233), + ) + for args in args_list: + yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device}) + +def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + yield SampleInput(make_arg((S, M, S))) + yield SampleInput(make_arg(())) + +def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs): + # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format + # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous + yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs) + + shapes = ( + (3, 5, 6), + (1, 1, 3, 5, 6), + (1, 1, 3, 5, 6, 1, 1), + (1, 0, 3, 5, 0, 2), + (1, 0, 3, 5, 0, 0, 1, 1, 2), + (), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_arg(shape)) + yield SampleInput(make_arg(shape).transpose(0, -1)) + yield SampleInput(make_arg(shape, noncontiguous=True)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) + + yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format}) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) + + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + for shape, strides, offset in strided_cases: + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset)) + yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format}) + + # channels last 2D + yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last}) + a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last}) + + # channels last 3D + yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d}) + a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3) + yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d}) + + +def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # list of tuples (shape, shape) defining the shapes of the input and output tensors + sample_shapes = [ + ((), ()), + ((S,), (1,)), + ((S, S), (1, 1)), + ((S, S), (1, S)), + ((S, S), (S, S)), + ((S, S, S), (S, 1, S)), + ] + + for input_shape, output_shape in sample_shapes: + yield SampleInput(make_arg(input_shape), args=(output_shape,)) + if output_shape == (): + continue + yield SampleInput(make_arg(input_shape), args=(list(output_shape),)) + yield SampleInput(make_arg(input_shape), args=(*output_shape,)) + + +def error_inputs_sum_to_size(op_info, device, **kwargs): + shape = (M, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M)) + yield ErrorInput(si, error_regex=err_msg) + + shape = (M + 1, S, S, M) + err_msg = "is not expandable to size" + si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1)) + yield ErrorInput(si, error_regex=err_msg) + + +def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + cases = (((S, S, S), (S * S, S)), + ((), ()), + ((), (1, 1, 1)), + ) + + for shape, args_or_shape in cases: + # Update `args` based on operator + if op_info.name == 'resize_': + # resize_ takes shape/tuple of ints, + args = (args_or_shape, ) + elif op_info.name == 'resize_as_': + # resize_as_ takes another tensor + args = (make_arg(shape, requires_grad=False), ) # type:ignore[assignment] + else: + raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") + + yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) + +def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = ( + # a, b, is_tensor_supported + ((S, S, S), (S * S, S), True), + ((S * S, S), (S, S, S), True), + ((S * S, S), (S, -1, S), False), # neg index + ((S * S * 2, S), (S, -1), False), # neg index + ((S,), (S,), True), + ((), (), False), # empty + ((), (1,), True), + ) + + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs) + + cases = ( + # a, b, is_tensor_supported + ((125,), (25, 5), True), + ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True), + ((16, 32), (2, 4, 1, 4, 4, 1, 4), True), + ((16, 12), (12, 16), True), + ((1, 16, 12), (12, 16), True), + ((1, 5, 1, 5), (25, 1), True), + ((2, 4, 2), (4, 4), True), + ((1, 4), (1, 1, 2, 1, 2), True), + ((3, 5, 7), (7, 5, 3), True), + ((1,), (), False), # empty + ((5, 0, 2, 3), (5, 0, 2, 3), True), + ((2, 1, 0, 3, 1), (5, 0), True), + ((1,), (), False), # empty + ((4, 5, 6), (4, 5, 6, 1, 1, 1), True), + ((), (1, 1, 1, 1), False), # empty + ) + + irreversible_cases = ( + ((), (-1,), False), # neg index, empty + ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if kwargs.get("tensor_arg"): + # convert to tensor + yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),)) + yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),)) + else: + yield SampleInput(make_arg(a), args=(b,)) + yield SampleInput(make_arg(b), args=(a,)) + + for a, b, is_tensor_supported in irreversible_cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield SampleInput(make_arg(a), args=(b,)) + +def error_inputs_view_reshape(op, device, **kwargs): + + cases = ( + # a, b, is_tensor_supported + # Reshape to different numel + ((2,), (), False), # empty + ((1, 3, 0), (), False), # empty + ((4, 3), (4, 2), True), + ((1, 3, 5), (5, 2, 2), True), + # No valid inference + ((1, 3, 5), (5, -1, 2), False), # neg index + # Two inferred shapes + ((1, 3, 5), (5, -1, -1), False), # neg index + ((1), (0, -1), False), # neg index + ((0, 5), (0, -1), False), # neg index + ) + + make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) + for a, b, is_tensor_supported in cases: + # skip unsupported cases + if kwargs.get("tensor_arg") and not is_tensor_supported: + continue + + if b == (5, -1, -1): + error_regex = "only one dimension can be inferred" + elif a == (0, 5): + error_regex = (r"cannot reshape tensor of 0 elements into shape " + r"\[0, -1\] because the unspecified dimension size " + r"-1 can be any value and is ambiguous") + else: + # to avoid having issues with a regex + shape = ', '.join(map(str, b)) + size = a if type(a) is int else functools.reduce(operator.mul, a, 1) + error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}" + + # convert to tensor + if kwargs.get("tensor_arg"): + b = make_arg(b, requires_grad=False) + + yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception, + error_regex=error_regex) + + +def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + yield SampleInput([make_tensor_partial(shape) for shape in shapes]) + +def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs): + cases: tuple[tuple, tuple] = ( # type: ignore[assignment] + ((S, 2, 1), (S, 3, 1)), + ((S), (S, 5)), ((), (1, S)) + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape1, shape2 in cases: + yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)]) + +def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs): + shapes = ((S, S, S), (S, S), (S, ), (),) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape in shapes: + yield SampleInput(make_tensor_partial(shape)) + if len(shape) > 1: + yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1) + +def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs) + + # shape x start_dim x end_dim + cases = ( + ((5, 4, 0, 1, 3, 7), 1, 3), + ((5, 4, 0, 1, 3, 7), 4, 5), + ((5, 4, 1, 1, 3, 7), 2, 3), + ((), 0, -1), + ((1,), 0, -1), + ((3, 7, 5), 1, 2), + ((4, 5), 1, 1), + ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2), + ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1), + ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1), + ((2, 4, 2), 0, 1), + ((4, 2, 2), 1, 2), + ((0, 3, 4, 5), 1, 3), + ) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for shape, start, end in cases: + yield SampleInput(make_arg(shape), args=(start, end,)) + yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,)) + yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,)) + +def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs): + # in_shape, dim, sizes + args = (((8,), 0, (8,)), + ((8,), 0, (4, 2)), + ((8,), -1, (2, 2, 2)), + ((8,), -1, (-1, 2)), + ((3, 6, 2), 1, (2, 3)), + ((3, 6, 2), -2, (2, 3)), + ((3, 6, 2), -2, (-1, 3)), + ((3, 2, 12), 2, (3, 2, 2)), + ((4, 0), 0, (2, 2)), + ((4, 0), 1, (2, 0, 0, 0)), + ) + make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + for in_shape, dim, sizes in args: + yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes)) + + +def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (1, 2)), + ((S, S, S), (-1, 2)), + ((S, S, S), (-1, -1)), + ((S, S, S), (1, -1)), + ((S, S), (-1, 2)), + ((S,), (0, 2)) + ) + + for shape, args in cases: + yield SampleInput(make_arg(shape), args=args) + + +def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (S, S), (1, 2)), + ((S, S, S), (S, S), (-1, 2)), + ((S, S, S), (S, S), (-1, -1)), + ((S, S, S), (S, S), (1, -1)), + ((S,), (), (0, 2)) + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + + +def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)), + ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)), + ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (1, 0, L, 1)), + ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)), + ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)), + ((L, L, L), (L, L, L,), (2, 0, L, 1)), + ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)), + ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)), + ) + + for input_shape, src_shape, args in cases: + input_ = make_arg(input_shape) + src = make_arg(src_shape) + yield SampleInput(input_, args=(src, *args)) + +def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, 1, 1), (S, S, S)), + ((S, 1, S), (S, S, S)), + ((S, 1, S), (-1, S, -1)), + ((S, 1, S), (-1, S, S)), + ((S, 1), (S, S, S)), + ((1,), (S, S, S)), + ((1, S), (1, 1, S)), + ((), ()), + ((), (1, 3, 2)), + ) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=(args,)) + +def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad) + + shapes = ((), + (2, 3)) + memory_format_options = [None, torch.contiguous_format] + + for shape, memory_format in itertools.product(shapes, memory_format_options): + yield SampleInput(make_arg(shape), + kwargs={'memory_format': memory_format} if memory_format else {}) + yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) + +def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device) + + cases = (((S, 1, 1), (S, S, S)), + ((), ()), + ((), (1, 1)), + ) + + for shape, shape_other in cases: + yield SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False),)) + + +def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + def make_bool_mask(shape): + # Make sure at least one element is nonzero, + # except for empty tensor + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + + if mask_t.numel() == 0: + return mask_t + elif mask_t.numel() == 1: + mask_t.fill_(True) + return mask_t + + if mask_t.sum() == 0: + def random_index(shape): + return tuple(random.randrange(0, max_idx) for max_idx in shape) + + mask_t[random_index(mask_t.shape)] = True + return mask_t + + return mask_t + + cases = (((M, M), (M, M), (M, M), False), + ((M, 1, M), (M, M), (M, M, 1), True), + ((), (), (), False), + ((M, 1, M), (), (M, M, 1), True), + ((), (M, M), (), True), + ((), (2), (1, 1), True), + ) + + for shape, mask_shape, other_shape, broadcasts_input in cases: + yield SampleInput(make_arg(shape), + args=(make_bool_mask(mask_shape), make_arg(other_shape)), + broadcasts_input=broadcasts_input) + +# TODO: add reference inputs for where(condition) signature +def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs) + + make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad) + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # noncontiguous + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), noncontiguous=True) + b = make_arg((3, 10, 3)).transpose(0, -1) + + # NOTE that the OpInfo for where takes samples of the form a, cond, b + yield SampleInput(a, args=(c, b)) + + # MPS does not support float64, which causes issues in the following tests + if torch.device(device).type == "mps": + return + + # type promoting + # FIXME(rec): shouldn't other_dtype be used two lines below? + other_dtype = torch.double if dtype is not torch.double else torch.long # noqa: F841 + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((10, 1), dtype=torch.long) + b = make_arg((10, 1)) + + yield SampleInput(a, args=(c, b)) + + # two python scalars + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((1,)).item() + b = make_arg((1,)).item() + + yield SampleInput(a, args=(c, b)) + + # NaN propagation + if dtype.is_floating_point or dtype.is_complex: + if dtype.is_floating_point: + nan = float('nan') + else: + # dtype.is_complex + nan = complex(float('nan'), float('nan')) + c = make_cond((1, 10, 3)) + a = make_arg((10, 3), noncontiguous=True) + a[2, 1] = nan + b = make_arg((1, 3)) + b[0, 2] = nan + + yield SampleInput(a, args=(c, b)) + + # Python scalars type promotion + for scalar in (0, 0.0, 2j, False): + yield SampleInput(scalar, args=(c, b)) + yield SampleInput(a, args=(c, scalar)) + + +def error_inputs_where(op_info, device, **kwargs): + shape = (S,) + err_msg = "Expected all tensors to be on the same device" + for devices in product(('cpu', device), repeat=3): + if len(set(devices)) == 2: + si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32), + args=(make_tensor(shape, dtype=torch.bool, device=devices[1]), + make_tensor(shape, device=devices[2], dtype=torch.float32))) + yield ErrorInput(si, error_regex=err_msg) + +def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + for input_t, as_tuple in product(inputs, [False, True]): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple)) + +def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) + + inputs = [] + for shape in sizes: + # construct input without any non-zero elements + zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) + inputs.append(zeros) + + # construct input with mixed zero and non-zero elements + mixed = make_arg(shape).requires_grad_(False) + mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) + mixed[mask_t] = 0 + inputs.append(mixed) + + nonzero_sizes = [0, 1, XS, S, M] + + for input_t, nonzero_size in product(inputs, nonzero_sizes): + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(size=nonzero_size)) + +def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + cases = (((S, S, S), (2,)), + ((S, S, S), (S, 1)), + ((S, S, S), (S, -1))) + + for case in cases: + shape, args = case + yield SampleInput(make_arg(shape), args=args) + +def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) + + # shape x chunks x dim + cases = ( + ((13, 9, 11), 17, -1), + ((13, 9, 11), 11, -1), + ((13,), 12, -1), + ((15,), 12, -1), + ((15,), 7, 0), + ((15,), 9, 0), + ((3, 7), 9, 1), + ((3, 7), 9, 0), + ((3, 7), 2, 0), + ((3, 7), 3, 0), + ((3, 7), 1, 0), + ((3, 7), 1, 1), + ((4, 4), 2, 0), + ) + + for shape, chunks, dim in cases: + yield SampleInput(make_arg(shape), args=(chunks, dim)) + +def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs): + def _tensor(shape, dtype=dtype, low=None, high=None): + return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) + + test_cases = [ + ((S, S, S), (2,)), + ((S, S, S), (2, 1,)), + ((S, S, S), (2, -1,)), + ((S, S, S), (2, 1, True,)), + ((S, S, S), (2, -1, True,)), + ((S,), (2, 0,)), + ((S,), (2, 0, True,)), + ((), (1,)), + ((), (1, 0,)), + ((), (1, 0, True)), + ] + + yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases) + +def error_inputs_kthvalue(op_info, device, **kwargs): + # tests overlapping output fails + t = make_tensor(10, dtype=torch.float32, device=device) + indices = torch.empty((), device=device, dtype=torch.long) + yield ErrorInput(SampleInput(t, 5, out=(t, indices)), + error_regex="unsupported operation") + + k_out_of_range_err = "selected number k out of range for dimension" + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3), + error_regex=k_out_of_range_err) + yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3), + error_regex=k_out_of_range_err) + +def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, + train=None, valid_input_dim=None, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if valid_input_dim: + cases = ((S,) * i for i in valid_input_dim) + else: + cases = ((S, S), (S,), ()) + p_vals = [0.0, 0.5, 1.0] + # This is to handle special case for feature_alpha_dropout which has different + # supported dtypes depending on `train` parameter + training_vals = [train] if train is not None else [True, False] + + for case, p, training in product(cases, p_vals, training_vals): + yield SampleInput(make_arg(case), p=p, training=training) + yield SampleInput(make_arg(case)) + +def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) + + cases = ((S, S, S, S), (S,), ()) + scale_vals = [0.0, 1.0, 2.0] + + for case, scale in product(cases, scale_vals): + yield SampleInput(make_arg(case), make_mask(case), scale) + +def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high, noncontiguous=False): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high, + noncontiguous=noncontiguous) + + def make_per_sample_weight(flag, idx): + # a tensor of float / double weights, or None + # to indicate all weights should be taken to be 1 + if flag: + return make_input(idx.shape) + return None + + offsets = torch.tensor([0, 3], device=device, dtype=torch.long) + for generate_per_sample_weight in (True, False): + for mode in ('sum', 'mean', 'max'): + # per_sample_weights is only supported for mode='sum' (got mode='****') + if generate_per_sample_weight and mode in ('mean', 'max'): + continue + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': offsets, 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # bag with zero length + idx = make_long_input((S,), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long), + 'mode': mode, + 'per_sample_weights': per_sample_weights}) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((S, S), low=0, high=M, noncontiguous=True) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((M, S)), args=(idx,), + kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) + + # The gradient vector at `padding_idx` is not updated. + # Negative padding_idx + idx = make_long_input((6,), low=0, high=S) + idx[0] = 4 + idx[4] = 4 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': -1, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((3, 3), low=0, high=S) + # Positive padding_idx + idx[0, 0] = 2 + idx[1, 1] = 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(make_input((S, S)), args=(idx,), + kwargs={'padding_idx': 2, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'max_norm': 1., 'norm_type': 1.0, + 'mode': mode, 'offsets': offsets, + 'per_sample_weights': per_sample_weights},) + + if mode != 'max': + # Scale the gradient based on the inverse frequency of a particular index. + # Note : smax mode does not support sparse weights + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'scale_grad_by_freq': True, 'mode': mode, + 'per_sample_weights': per_sample_weights},) + + # gradcheck not implemented for sparse tensors. + # Note : max mode does not support sparse weights + idx = make_long_input((6, ), low=0, high=S) + weights = make_input((S, S)) + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + idx = make_long_input((6, ), low=0, high=S) + idx[0] = 1 # freq more than 1 + idx[1] = 1 # freq more than 1 + idx[3] = 0 # padding_idx + weights = make_input((S, S)) * 2 + per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0, + 'max_norm': 1., 'offsets': offsets, + 'mode': mode, 'per_sample_weights': per_sample_weights}) + + +def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape): + return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_long_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high) + + # 0-D index tensor + idx = make_long_input((), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 1-D index tensor + idx = make_long_input((S,), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + # 2-D index tensor + idx = make_long_input((S, S), low=0, high=M) + yield SampleInput(make_input((M, S)), args=(idx,),) + + if not requires_grad: + # Following inputs return different gradient from the numerical gradient. + # This is expected and relevant tests are present in `test_nn.py`. + + # The gradient vector at `padding_idx` is not updated. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 2 + idx[1, 1] = 2 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},) + + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 4 + idx[1, 1] = 4 + yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},) + + # Due to inplace renorming of weight, the numerical gradient doesn't match the + # analytical gradient. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},) + + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},) + + # Scale the gradient based on the inverse frequency of a particular index. + idx = make_long_input((2, 2), low=0, high=S) + idx[0, 0] = 1 + idx[0, 1] = 1 + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},) + + # gradcheck not implemented for sparse tensors. + idx = make_long_input((2, 2), low=0, high=S) + weights = make_input((S, S)) + yield SampleInput(weights, args=(idx,), kwargs={'sparse': True}) + + idx = make_long_input((3, 3), low=0, high=S) + idx[0, 0] = 1 # freq more than 1 + idx[0, 1] = 1 # freq more than 1 + idx[1, 0] = 0 # padding_idx + weights = make_input((S, S)) * 2 + yield SampleInput(weights, args=(idx,), + kwargs={'sparse': True, 'scale_grad_by_freq': True, + 'padding_idx': 0, 'max_norm': 1.}) + + +def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs): + def make_input(shape, *, low, high): + return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad) + + shapes = ((), (S,), (L, M, S)) + num_classess = (-1, 10) + + return ( + SampleInput( + make_input( + shape, + low=0, + high=10 if num_classes == -1 else num_classes // 2, + ), + kwargs=dict(num_classes=num_classes), + ) + for shape, num_classes in itertools.product(shapes, num_classess) + ) + + +def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs): + rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad) + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Although most losses also support the reduce and size_average combination instead of reduce, the former is + # deprecated since 0.4.1 and thus is not tested + shapes_and_kwargs = ( + ((), None), + ((S,), dict(reduction="mean")), + ((S,), dict(reduction="sum")), + ((S,), dict(reduction="none")), + ((S, S), None), + ((S, S, S), None), + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput(_make_tensor(shape), + args=(_make_tensor(shape, requires_grad=rhs_requires_grad),), + kwargs=kwargs) + +def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = ("bilinear", "nearest") + align_cornerss = (False, True) + padding_modes = ("zeros", "border", "reflection") + + for dim in (2, 3): + + modes_ = (*modes, "bicubic") if dim == 2 else modes + + for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, *[S] * dim)), + _make_tensor((batch_size, *[S] * dim, dim)), + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + +def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): + + batch_size = 2 + num_channels = 3 + height = 345 + width = 456 + modes = ("bilinear", "nearest", "bicubic") + align_cornerss = (False, True) + padding_modes = ('zeros', 'border', 'reflection') + + # Create an affine transformation matrix + a = torch.deg2rad(torch.tensor(45.0)) + ca, sa = torch.cos(a), torch.sin(a) # rotation angles + s1, s2 = 1.23, 1.34 # scales + + theta = torch.tensor([[ + [ca / s1, sa, 0.0], + [-sa, ca / s2, 0.0], + ]], dtype=dtype, device=device) + theta = theta.expand(batch_size, 2, 3).contiguous() + + x = torch.arange(batch_size * num_channels * height * width, device=device) + x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8) + x = x.to(dtype=dtype) + x.requires_grad_(requires_grad) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + grid = torch.nn.functional.affine_grid( + theta, size=(batch_size, num_channels, height, width), align_corners=align_corners + ) + yield SampleInput( + x, + grid, + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = (0, 1, 2) + align_cornerss = (False, True) + padding_modes = (0, 1, 2) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, S, L)), + _make_tensor((batch_size, M + 3, M, 2)), + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_grid_sampler_3d(op_info, device, dtype, requires_grad, **kwargs): + _make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-1, high=1) + # Test both out-of-range and in-range grid values + _make_grid = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-4, high=4) + + modes = (0,) + padding_modes = (0, 1, 2) + align_cornerss = (False, True) + shape_pairs = [ + # [input_shape, grid_shape] + [(1, 1, 2, 2, 2), (1, 1, 1, 1, 3)], + [(2, 3, S, L, L), (2, M + 2, M + 1, M, 3)], + [(L, L + 1, L + 2, L + 3, L + 4), (L, M + 2, M + 1, M, 3)], + [(M, M + 1, M + 2, M + 3, M + 4), (M, L + 3, L + 2, L + 1, 3)], + [(L, M + 1, M + 2, M + 3, M + 4), (L, L + 3, L + 2, L + 1, 3)], + ] + + params_prod = itertools.product(modes, padding_modes, align_cornerss, shape_pairs) + + for mode, padding_mode, align_corners, (input_shape, grid_shape) in params_prod: + yield SampleInput( + _make_input(input_shape), + _make_grid(grid_shape), + mode, + padding_mode, + align_corners, + ) + +def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def make_target(shape): + shape = () if len(shape) == 1 else (shape[0], ) + t = torch.randint(0, 2, shape, device=device, dtype=torch.long) + # Label with -1 or 1 + t = t * 2 - 1 + target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad) + return target + + shapes = ((S, S), (S,)) + reductions = ('none', 'mean', 'sum') + for s, r in product(shapes, reductions): + yield SampleInput( + make_input(s), + args=(make_input(s), make_target(s)), + kwargs=dict(reduction=r, margin=random.uniform(-1, 1)) + ) + +def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs): + input_length = 50 + batch = 16 + num_char = 20 + target_length = 30 + + def make_log_probs(s): + t = make_tensor(s, device=device, dtype=dtype) + log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad) + return log_probs + + reductions = ('none', 'mean', 'sum') + zero_inf = (True, False) + lengths_type = (list, torch.Tensor) + for r, z, lt in product(reductions, zero_inf, lengths_type): + log_probs = make_log_probs((input_length, batch, num_char)) + targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device) + input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device) + target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device) + + # Dont generate int[] types if reduction = "Mean" since this results in non composite compliant calls + # to ctc_loss.IntList since a tensor needs to be created from the target lengths. + # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor + # e.g. via std::copy. Similarly symbolic/real tracing with fx will also not work + if lt is list and r in ["none", "sum"]: + input_lengths = input_lengths.tolist() + target_lengths = target_lengths.tolist() + + yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), + kwargs=dict(reduction=r, zero_infinity=z)) + + +def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + shape = (2, 3) + num_classes = shape[1] + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # FIXME: Derivative wrt. weight not implemented + make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False) + + def make_target(shape, zeros=False): + s = (shape[0], *shape[2:]) if len(shape) > 1 else () + if zeros: + return torch.zeros(s, device=device, dtype=torch.long) + else: + return make_tensor(s, + low=0, + high=shape[1] if len(shape) > 1 else shape[0], + device=device, + dtype=torch.long) + + + def gen_shape_kwargs(): + # Batched, non-batched and 2d + shapes = (shape, (num_classes,), shape + (2, 2)) + reductions = ('none', 'mean', 'sum') + for reduction, s in product(reductions, shapes): + yield make_input(s), make_target(s), dict(reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction) + yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction) + if dtype.is_floating_point or dtype.is_complex: + yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction) + t = make_target(s) + ignore = num_classes // 2 + # If "mean", nll returns NaN, so it's not differentiable at those points + if t.eq(ignore).all() and reduction == "mean": + t.fill_(0) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction) + yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight()) + # Test ignoring all the targets + # If "mean", nll returns NaN, so it's not differentiable at those points + if reduction != "mean": + yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target,), kwargs=kwargs) + + target = torch.tensor([-1, 2], device=device, dtype=torch.long) + yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1}) + + +def sample_inputs_binary_cross_entropy_with_logits( + op_info, device, dtype, requires_grad, **kwargs +): + make = partial(make_tensor, device=device, dtype=dtype) + make_prob = partial(make, low=0, high=1) + reductions = ("mean", "sum", "none") + + def make_weight_shape_kwargs(): + kwargs = [] + for shape in ((1,), (1, S), (S), (S, S)): + kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions]) + return kwargs + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *make_weight_shape_kwargs(), + *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions], + ] + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + make(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad)) + mask = torch.tensor([[0, 1, 0, 1, 0], + [1, 1, 1, 1, 0], + [0, 0, 0, 1, 0], + [1, 0, 1, 1, 0], + [1, 0, 0, 1, 0]], dtype=torch.bool, device=device) + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) + t[mask] = 0 + yield SampleInput(t) + + t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(t) + + yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + +def _generate_sample_shape_reduction(): + shapes = ((S,), (S, S), (S, S, S)) + reductions = ('none', 'mean', 'sum') + yield from product(shapes, reductions) + +def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0 + make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape(shape): + yield shape + # Broadcast + yield (*shape[:-1], 1) + yield shape[:-1] + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for t_s, v_s in product(gen_shape(s), gen_shape(s)): + yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + yield ( + _make_tensor(s), _make_tensor(t_s), make_var(v_s), + dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r) + ) + + for input, target, var, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, var, ), kwargs=kwargs) + +def error_inputs_gaussian_nll_loss(op_info, device, **kwargs): + _make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"), + error_type=ValueError, error_regex="abc is not valid") + + # var is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)), + error_type=ValueError, error_regex="var is of incorrect size") + + # target is of incorrect shape + yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)), + error_type=RuntimeError, + error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) " + r"at non-singleton dimension 2")) + +def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for s, r in _generate_sample_shape_reduction(): + yield _make_tensor(s), _make_tensor(s), dict(reduction=r) + +def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = 1 + target[~mask] = -1 + d['margin'] = random.uniform(-9, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + + # scalar input and target. + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(_make_tensor(()), args=(_make_tensor(()), )) + +def error_inputs_hinge_embedding_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex='is not a valid value') + +def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs) + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + for reduction in ('sum', 'mean', 'none'): + if dtype.is_floating_point: # only supports ints and floats + # NaN propagation + inp = make_input((10, )) + inp[2] = float('nan') + target = make_input((10, )) + # target should contain either 1 or -1 as per docs + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Inf Handling + inp = make_input((10, )) + inp[4] = float('inf') + target = make_input((10, )) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + + # Broadcasting + inp = make_input((5, 5)) + target = make_input((1, 5)) + mask = torch.rand_like(target) > 0.5 + target[mask] = -1 + target[~mask] = 1 + yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) + +def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs): + for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): + d['delta'] = random.uniform(1e-3, 9) + yield SampleInput(input, args=(target, ), kwargs=d) + +def error_inputs_huber_loss(op, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + # invalid reduction value + err = 'is not a valid value for reduction' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), + error_type=ValueError, error_regex=err) + # delta <= 0 + for delta in (0, -1): + err = 'huber_loss does not support non-positive values for delta.' + yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}), + error_type=RuntimeError, error_regex=err) + +def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + def gen_shape_kwargs(): + for s, r in _generate_sample_shape_reduction(): + for li in (True, False): + for f in (True, False): + i1 = _make_tensor(s) + i2 = _make_tensor(s) + # For Poisson NLL Loss, + # target is assumed to be from + # Poisson Distribution which + # always has positive samples + t1 = _make_tensor(s, low=0) + t2 = _make_tensor(s, low=0) + + if not li: + i1.abs_() + i2.abs_() + t1.abs_() + t2.abs_() + + yield ( + i1, t1, + dict(log_input=li, full=f, reduction=r) + ) + yield ( + i2, t2, + dict(log_input=li, full=f, + eps=random.uniform(1e-8, 1e-3), + reduction=r) + ) + + for input, target, kwargs in gen_shape_kwargs(): + yield SampleInput(input, args=(target, ), kwargs=kwargs) + + # test INT_TO_FLOAT promotion + if dtype.is_complex: + for d in (torch.bool, torch.int64): + yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),)) + yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),)) + +def error_inputs_poisson_nll_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(5\) must match the ' + r'size of tensor b \(4\) at non-singleton ' + r'dimension 1)')) + +def error_inputs_soft_margin_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)')) + +def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs): + make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad) + + kwargss = ( + *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)], + dict(swap=True), + *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")], + ) + + for kwargs in kwargss: + input = make() + args = (make(), make()) + if with_distance: + kwargs["distance_function"] = torch.nn.PairwiseDistance() + yield SampleInput(input, args=args, kwargs=kwargs) + +def error_inputs_triplet_margin_loss(op_info, device, **kwargs): + make_input = partial(make_tensor, device=device, dtype=torch.float32) + + samples = ( + # input, args, kwargs, error_type, error_regex + # invalid reduction + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(reduction="abc"), + ValueError, "abc is not a valid value for reduction"), + + # invalid margin + (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), + dict(margin=-1.0), + ValueError, "margin must be greater than 0, got -1.0"), + + # shape mismatch + (make_input(3, 5), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(5\) must match the size of tensor b \(4\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 5), make_input(3, 4)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + (make_input(3, 4), (make_input(3, 4), make_input(3, 5)), + {}, + RuntimeError, + (r'(Attempting to broadcast a dimension of length|' + r"The size of tensor a \(4\) must match the size of tensor b \(5\) " + r"at non-singleton dimension 1)")), + + # different dimensions + (make_input(3,), (make_input(3, 4), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 1D, positive 2D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3,), make_input(3, 4)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 1D, " + r"and negative 2D inputs")), + (make_input(3, 4), (make_input(3, 4), make_input(3,)), + {}, + RuntimeError, + (r"The anchor, positive, and negative tensors are expected to have " + r"the same number of dimensions, but got: anchor 2D, positive 2D, " + r"and negative 1D inputs")), + ) + + for input, args, kwargs, error_type, error_regex in samples: + yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), + error_type=error_type, error_regex=error_regex) + +def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): + make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) + make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad) + make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + M, N, K = 15, 32, 16 + samples = [] + # two e4m3 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e4m3 mat2 e5m2 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e5m2((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # mat1 e5m2 mat2 e4m3 + mat1 = make_mat_e5m2((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append(SampleInput(mat1, mat2, scale1, scale2)) + + yield from samples + +def sample_inputs_scaled_mm_v2(op_info, device, dtype, requires_grad, **kwargs): + from torch.nn.functional import ScalingType, SwizzleType + make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) + + make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + + M, N, K = 15, 32, 16 + samples = [] + # two e4m3 tensorwise + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((1,)) + scale2 = make_scale((1,)) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.TensorWise, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.TensorWise, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # two e4m3 rowwise + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + scale1 = make_scale((M, 1)) + scale2 = make_scale((1, N)) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.RowWise, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.RowWise, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + M, K, N = 256, 512, 768 + mat1 = make_mat_e4m3((M, K)) + mat2 = make_mat_e4m3((K, N)).t().contiguous().t() + + dmajor, dminor = torch.cuda.get_device_capability() + + if dmajor == 9 and not torch.version.hip: + # 1x128 x 1x128 + scale1 = make_scale((K // 128, M)).t() + scale2 = make_scale((K // 128, N)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # 128x128 x 1x128 + L4 = round_up(K // 128, 4) + scale1 = make_scale((M // 128, L4)).t() + scale2 = make_scale((K // 128, N)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise128x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # 1x128 x 128x128 + L4 = round_up(K // 128, 4) + scale1 = make_scale((K // 128, M)).t() + scale2 = make_scale((N // 128, L4)).t() + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x128, ], + [SwizzleType.NO_SWIZZLE, ], + [scale2, ], + [ScalingType.BlockWise128x128, ], + [SwizzleType.NO_SWIZZLE, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + + if dmajor >= 10: + # MXFP8 + scale1 = make_scale((M, K // 32)).to(torch.float8_e8m0fnu) + scale2 = make_scale((K // 32, N)).to(torch.float8_e8m0fnu) + samples.append( + SampleInput( + mat1, + mat2, + [scale1, ], + [ScalingType.BlockWise1x32, ], + [SwizzleType.SWIZZLE_32_4_4, ], + [scale2, ], + [ScalingType.BlockWise1x32, ], + [SwizzleType.SWIZZLE_32_4_4, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + # NVFP4 + # [M, K] -> [M, K // 2] + # [K, N] -> [K // 2, N] + mat1_fp4 = _bfloat16_to_float4_e2m1fn_x2(mat1.to(torch.bfloat16)) + mat2_fp4 = _bfloat16_to_float4_e2m1fn_x2(mat2.to(torch.bfloat16).t()).t() + scale1 = make_scale((M, K // 16)).to(torch.float8_e4m3fn) + global_scale1 = make_scale((1, )) + scale2 = make_scale((K // 16, N)).to(torch.float8_e4m3fn) + global_scale2 = make_scale((1, )) + samples.append( + SampleInput( + mat1_fp4, + mat2_fp4, + [scale1, global_scale1], + [ScalingType.BlockWise1x16, ScalingType.TensorWise], + [SwizzleType.SWIZZLE_32_4_4, ], + [scale2, global_scale2], + [ScalingType.BlockWise1x16, ScalingType.TensorWise], + [SwizzleType.SWIZZLE_32_4_4, ], + None, # bias + torch.bfloat16, # out_dtype + ) + ) + + + yield from samples + +def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 + + dim_3_q_shape = (batch, seq_q, head_dim) + dim_3_kv_shape = (batch, seq_kv, head_dim) + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim)) + + qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] + samples = [] + gqa_options = [True, False] + causal_options = [True, False] + for qkv_shape, is_causal, dropout_p, _enable_gqa in product( + qkv_shapes, causal_options, [0.0, 0.5], gqa_options): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q), + make(shape_kv), + make(shape_kv), + is_causal=is_causal, + dropout_p=dropout_p + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim + 8)), + is_causal=is_causal, + dropout_p=dropout_p + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, num_heads, seq_q, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + make((batch, num_heads, seq_kv, head_dim)), + attn_mask=make((seq_q, seq_kv)), + is_causal=False, + dropout_p=0.0) + ) + + yield from samples + + +def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + mask_types = [1, 2] # UpperLeft, LowerRight + scales = [None, 1.0] + + for qkv_shape, _is_causal, dropout_p, mask_type, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=mask_type, + compute_log_sumexp=requires_grad, + scale=scale, + seqlen_k=None + )) + + # Add non standard shapes + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim + 8)), + bias=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + + # Add an attn_mask + samples.append( + SampleInput( + make((batch, seq_q, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + make((batch, seq_kv, num_heads, head_dim)), + bias=make(batch, num_heads, seq_q, seq_kv), + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + dropout_p=dropout_p, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None + ) + ) + + # jagged (with query/keys offsets) + cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device) + cu_seqlens_k[-1] = 62 + cu_seqlens_k[0] = 0 + samples.append( + SampleInput( + make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + make((64, 64)).view(-1, 8, 8).unsqueeze(0), + bias=None, + cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device), + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=2, + max_seqlen_k=2, + dropout_p=0.0, + custom_mask_type=0, # No Mask + compute_log_sumexp=requires_grad, + scale=None, + seqlen_k=None, + ) + ) + + yield from samples + +def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + batch, num_heads, head_dim = 4, 4, 8 + seq_q = 11 + seq_kv = 32 + + dim_4_q_shape = (batch, num_heads, seq_q, head_dim) + dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) + + qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] + samples = [] + scales = [None, 1.0] + + for qkv_shape, is_causal, dropout_p, scale in product( + qkv_shapes, [True, False], [0.0, 0.5], scales): + shape_q, shape_kv = qkv_shape + samples.append(SampleInput( + make(shape_q).transpose(1, 2), + make(shape_kv).transpose(1, 2), + make(shape_kv).transpose(1, 2), + cum_seq_q=None, + cum_seq_k=None, + max_q=seq_q, + max_k=seq_kv, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + )) + + yield from samples + +def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shape = (3,) + batched_shape = (2, *shape) + shapes_and_kwargs = [ + (shape, None), + (batched_shape, None), + (shape, dict(keepdim=True)), + (batched_shape, dict(keepdim=True)), + (shape, dict(p=5.0)), + (shape, dict(p=-1.0)), + (shape, dict(eps=1.0)), + ] + + return ( + SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs + ) + +def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor) + for upscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), upscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + yield from ( + SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor) + for downscale_factor in (1, 3) + ) + yield from ( + SampleInput(make_arg(shape), downscale_factor=1) + for shape in [ + (1, 0, 1, 1), + (1, 1, 0, 1), + (1, 1, 1, 0), + ] + ) + +def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + shapes_groups = [ + ((1, 4, 10, 10), 2), + ((2, 6, 8, 8), 3), + ((2, 8, 5, 5), 4), + ] + + yield from ( + SampleInput(make_arg(shape), args=(groups,)) + for shape, groups in shapes_groups + ) + +def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype) + # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps + # otherwise perturbation calculation causes Tensor value to become negative triggering + # a device-side hardware assertion + make_prob = partial(make, low=1e-6, high=1) + + reductions = ("mean", "sum", "none") + + shapes_and_kwargs = [ + *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], + *[((S, S), dict(reduction=reduction)) for reduction in reductions], + *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions], + ] + + if logits: + shapes_and_kwargs.extend( + [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions] + ) + + for shape, kwargs in shapes_and_kwargs: + yield SampleInput( + (make if logits else make_prob)(shape, requires_grad=requires_grad), + args=(make_prob(shape, requires_grad=requires_grad),), + kwargs=kwargs, + ) + +def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs): + sample_shapes = [(), (S), (S, S, S)] + atols = [1e-2, 1e-16] + rtols = [1e-1, 0.5] + for s, rtol, atol in product(sample_shapes, rtols, atols): + # close sample + t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + close = (t + atol).detach().requires_grad_(requires_grad) + yield SampleInput(t, close, rtol=rtol, atol=atol) + + # random sample + a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) + yield SampleInput(a, b, rtol=rtol, atol=atol) + + +def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + # test COMPLEX_TO_FLOAT promotion + if dtype.is_complex: + make = partial(make_tensor, (), device=device, requires_grad=requires_grad) + yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) + yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) + +def error_inputs_l1_loss(op_info, device, **kwargs): + make = partial(make_tensor, device=device, dtype=torch.float32) + + # invalid reduction value + yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), + kwargs={'reduction': 'abc'}), + error_type=ValueError, + error_regex='abc is not a valid value for reduction') + # invalid input shapes + yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), + error_regex=(r'(Attempting to broadcast a dimension of length|' + r'The size of tensor a \(4\) must match the ' + r'size of tensor b \(5\) at non-singleton ' + r'dimension 1)') + ) + +def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) + + make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad) + + # This test case always triggers the smooth condition, since absolute difference of input and target + # is smaller than beta + yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5)) + yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0)) + +def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs): + # kl_div works with inputs in [0, 1] (aka the pdf of a probability measure) + # Then log [0, 1] = (-inf, 0], so this is the log space + make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad) + + def make_log(shape): + out = torch.nn.functional.log_softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + def make_prob(shape): + out = torch.nn.functional.softmax(make_arg(shape), -1) + out.requires_grad_(requires_grad) + return out + + shapes = ((2,), (2, 3)) + reductions = ("none", "mean", "batchmean", "sum") + for shape, reduction, log_target in product(shapes, reductions, (True, False)): + input = make_log(shape) + target = make_log(shape) if log_target else make_prob(shape) + yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target)) + +def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2)) + yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))) + +def reference_pdist(input, p=2): + pdist = scipy.spatial.distance.pdist + if p == 0: + output = pdist(input, "hamming") * input.shape[1] + elif p == float("inf"): + output = pdist(input, lambda x, y: np.abs(x - y).max()) + else: + output = pdist(input, "minkowski", p=p) + return output.astype(input.dtype) + +def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make_input(())) + yield SampleInput(make_input((2,))) + yield SampleInput(make_input((2, 2))) + yield SampleInput(make_input((2,)), offset=1) + yield SampleInput(make_input((2,)), offset=-1) + + +_UNPOOL_NAME_TO_DIM = { + 'nn.functional.max_unpool1d': 1, + 'nn.functional.max_unpool2d': 2, + 'nn.functional.max_unpool3d': 3 +} + + +def error_inputs_max_unpool(op_info, device, **kwargs): + """Error inputs for max_unpool: shape mismatch between input and indices.""" + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] + + # Create mismatched shapes for input and indices + kwargs_dict = {'kernel_size': 3, 'stride': 2, 'padding': 0} + if pool_dim == 1: + input_shape = (8, 8) + indices_shape = (8, 7) + elif pool_dim == 2: + input_shape = (1, 1, 4, 4) + indices_shape = (1, 1, 4, 1) + else: # pool_dim == 3 + input_shape = (1, 1, 4, 4, 4) + indices_shape = (1, 1, 4, 4, 1) + + yield ErrorInput( + SampleInput( + make_arg(input_shape), + args=(torch.zeros(indices_shape, device=device, dtype=torch.long),), + kwargs=kwargs_dict + ), + error_type=RuntimeError, + error_regex='Expected shape of indices to be' + ) + + +def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + unpool_name_to_pool_method_dict = { + 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d, + 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d, + 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d + } + + unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()} + + pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name] + pool_method = unpool_name_to_pool_method_dict[op_info.name] + + pool_op_info = copy.copy(op_info) + pool_op_info.name = unpool_to_pool_name_dict[op_info.name] + + for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs): + # shapes (C, ...) do not work as of now, + # see https://github.com/pytorch/pytorch/issues/68337 + # TODO: remove once the issue is resolved + if sample.input.dim() != pool_dim + 2: + continue + + # No dilation > 1 for max_unpool, + # see https://github.com/pytorch/pytorch/issues/68420 + if sample.kwargs['dilation'] != 1: + continue + + # Can't unpool without indices + if sample.kwargs['return_indices']: + pool, indices = pool_method(sample.input, **sample.kwargs) + # arg has to be a leaf + arg = pool.detach().requires_grad_(requires_grad) + sample_kwargs = { + 'kernel_size': sample.kwargs['kernel_size'], + 'stride': sample.kwargs['stride'], + 'padding': sample.kwargs['padding'], + # output_size could be None but we specify it explicitly + # to compensate for the information lose in pool due + # to the floor/ceil operation used to compute the shapes + 'output_size': sample.input.size() + } + + yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs) + +def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs): + for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): + indices = sample.args[0] + # The samples for max_unpool are generated with max_pool. + # It could be that a single element from the max_pool's + # input is mapped to several locations in its output. + # This situation leads to failed gradchecks because + # the finite difference algorithm perturbs the elements + # of the output one by one, and not in classes of + # equivalences determined by whether two elements + # in the output are coming from the same location in the + # input (simply put, they have the same corresponding index). + # So, there are two ways to resolve this issue: + # 1. Extract a perturbation for one element and apply it all + # the elements from the same equivalence class, or + # 2. Make sure that the equivalence classes are all singletons, + # i.e. the index tensor has to be comprised of only unique + # indices. + # Here we go with the solution 2, the easiest of all. + if indices.unique().numel() == indices.numel(): + yield sample + +def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if requires_grad: + # backward tests would take too long to complete, causing the job timeout. + bsz = 2 + is_batcheds = (True,) + use_separate_proj_weights = (False,) + emb_sizes = (2,) + src_lens = (XS,) + tgt_lens = (XS,) + heads = (2,) + dropouts = (0.5,) + mask_types = ("2d",) + else: + bsz = 2 + is_batcheds = (False, True) + use_separate_proj_weights = (False, True) + emb_sizes = (2, 4) + src_lens = (XS,) + tgt_lens = (XS, S) + heads = (1, 2) + dropouts = (0.0, 0.5) + mask_types = (None, "2d", "3d") + + for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product( + is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts + ): + attn_mask = None + if mask_type == "2d": + attn_mask = make_input(src_len, tgt_len) + elif mask_type == "3d": + attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len) + + if is_batched: + q = make_input(src_len, bsz, emb_size) + k = make_input(tgt_len, bsz, emb_size) + v = make_input(tgt_len, bsz, emb_size) + else: + q = make_input(src_len, emb_size) + k = make_input(tgt_len, emb_size) + v = make_input(tgt_len, emb_size) + if use_separate_proj_weight: + in_proj_weight = None + q_proj_weight = make_input(emb_size, emb_size) + k_proj_weight = make_input(emb_size, emb_size) + v_proj_weight = make_input(emb_size, emb_size) + else: + in_proj_weight = make_input(emb_size * 3, emb_size) + q_proj_weight = None + k_proj_weight = None + v_proj_weight = None + + bias_k = make_input(emb_size) + bias_v = make_input(emb_size) + in_proj_bias = make_input(emb_size * 3) + out_proj_weight = make_input(emb_size, emb_size) + out_proj_bias = make_input(emb_size) + sample_args = ( + k, v, emb_size, num_heads, in_proj_weight, + in_proj_bias, bias_k, bias_v, False, + dropout_p, out_proj_weight, out_proj_bias + ) + sample_kwargs = { + "q_proj_weight" : q_proj_weight, + "k_proj_weight" : k_proj_weight, + "v_proj_weight" : v_proj_weight, + "attn_mask" : attn_mask, + "training" : dropout_p > 0.0, + "use_separate_proj_weight" : use_separate_proj_weight + } + + yield SampleInput(q, args=sample_args, kwargs=sample_kwargs) + + +# Includes some values such that N * N won't be a multiple of 4, +# which should ensure we test the vectorized and non-vectorized +# kernel code paths. +NUM_SIZE0_TENSORS = 10000 +foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300] +_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None} + + +class ForeachRightmostArgType(enum.Enum): + TensorList = enum.auto() + ScalarList = enum.auto() + Scalar = enum.auto() + Tensor = enum.auto() + + +class ForeachSampleInput(SampleInput): + # For TensorList Scalar/Tensor, we compute the reference + # by converting it into TensorList ScalarList/TensorList and + # then converting into multiple Tensor Scalar/Tensor. + # ref_args contains the args converted to TensorList ScalarList/TensorList + ref_args: Any + disable_fastpath: bool + + def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs): + super().__init__(*args, **kwargs) + self.ref_args = ref_args or self.args + self.disable_fastpath = disable_fastpath + + +class foreach_inputs_sample_func: + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + self.arity = arity + self._set_rightmost_arg_types( + rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, + ) + self._intersperse_empty = (True, False) + + def _set_rightmost_arg_types( + self, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool, + ) -> None: + self._rightmost_arg_types = [ForeachRightmostArgType.TensorList] + if self.arity > 1: + if rightmost_supports_scalar: + self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar) + if rightmost_supports_scalarlist: + self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList) + if rightmost_supports_tensor: + self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) + + def _sample_rightmost_arg( + self, + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ): + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] + if rightmost_arg_type == ForeachRightmostArgType.Tensor: + return [make_tensor( + (), device=device, dtype=dtype, + noncontiguous=_foreach_inputs_kwargs["noncontiguous"], + requires_grad=_foreach_inputs_kwargs.get("requires_grad", False), + )] + should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16) + + def sample_float(): + s = random.random() + if should_use_simpler_scalars: + return 1.0 if s > 0.5 else 2.0 + else: + return 1.0 - s + + high = 2 if should_use_simpler_scalars else 9 + if rightmost_arg_type == ForeachRightmostArgType.ScalarList: + scalarlist_list = [] + scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) + + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalarlist_list.append([sample_float() for _ in range(num_tensors)]) + if allow_higher_dtype_scalars or dtype.is_complex: + scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) + scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) + scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) + return scalarlist_list + if rightmost_arg_type == ForeachRightmostArgType.Scalar: + scalars = [] + scalars.append(random.randint(1, high + 1)) + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalars.append(sample_float()) + if allow_higher_dtype_scalars or dtype.is_complex: + scalars.append(complex(sample_float(), sample_float())) + scalars.append(True) + return scalars + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + if self.arity == 1: + if "foreach_abs" in opinfo.name and dtype in complex_types(): + return True + # unary + if opinfo.ref in (torch.abs, torch.neg): + return False + if opinfo.ref_inplace == torch.Tensor.zero_: + return False + return dtype in integral_types_and(torch.bool) + if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor: + return None + if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool): + return True + if any( + foreach_name in opinfo.name + for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum") + ) and dtype in integral_types_and(torch.bool): + return True + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if "foreach_add" in opinfo.name and dtype == torch.bool: + disable_fastpath = True + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.Scalar: + disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) + if isinstance(rightmost_arg, bool): + disable_fastpath |= dtype == torch.bool + if opinfo.ref in (torch.add, torch.mul): + disable_fastpath = False + elif isinstance(rightmost_arg, int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg, float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg, complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}") + return disable_fastpath + elif rightmost_arg_type == ForeachRightmostArgType.ScalarList: + disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool) + elmt_t = type(rightmost_arg[0]) + has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg) + if not has_same_type: + return dtype not in complex_types() + if isinstance(rightmost_arg[0], bool): + if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool: + disable_fastpath = False + elif isinstance(rightmost_arg[0], int): + disable_fastpath |= dtype == torch.bool + elif isinstance(rightmost_arg[0], float): + disable_fastpath |= dtype in integral_types_and(torch.bool) + elif isinstance(rightmost_arg[0], complex): + disable_fastpath |= dtype not in complex_types() + else: + raise AssertionError(f"Invalid scalarlist of {rightmost_arg}") + return disable_fastpath + else: + raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") + + def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param: + if dtype in integral_types_and(torch.bool): + kwargs["alpha"] = 3 + elif dtype.is_complex: + kwargs["alpha"] = complex(3, 3) + else: + kwargs["alpha"] = 3.14 + if self.arity > 1: + kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype) + return kwargs + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + for _rightmost_arg_type in self._rightmost_arg_types: + zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) + zero_size_foreach_inputs_kwargs["zero_size"] = True + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + if self.arity > 1: + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + args.append( + self._sample_rightmost_arg( + opinfo, + ForeachRightmostArgType.TensorList, + device, + dtype, + NUM_SIZE0_TENSORS, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **zero_size_foreach_inputs_kwargs, + )[0]) + kwargs = self._sample_kwargs( + opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) + else: + args = [] + kwargs = {} + if opinfo.ref in (torch.abs, torch.neg): + kwargs["disable_fastpath"] = False + else: + kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _foreach_inputs_kwargs["zero_size"] = False + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + # add empty tensor interspersion to test fully fixing #100701 + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): + if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): + # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy + continue + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + args = [] + if self.arity > 1: + args = [ + sample_inputs_foreach( + None, device, dtype, num_tensors, **_foreach_inputs_kwargs) + for _ in range(self.arity - 2) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, + **_foreach_inputs_kwargs) + for rightmost_arg in rightmost_arg_list: + args.append(rightmost_arg) + kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype) + ref_args = args + if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor): + ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]] + sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs) + yield sample + args.pop() + else: + yield ForeachSampleInput( + input, + *args, + disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype), + ) + + +class foreach_max_sample_func(foreach_inputs_sample_func): + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) + self._intersperse_empty = (False,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + return [] + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return False + + +class foreach_norm_sample_func(foreach_inputs_sample_func): + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')): + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors) + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, ord, out_dtype, intersperse_empty_tensors in product( + num_input_tensors, + (0, 1, 2, -1, -2, float('inf'), float('-inf')), + (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,), + (True, False), + ): + # inf norm and negative norms on empty tensors is not supported by our reference func vector norm: + # linalg.vector_norm cannot compute the inf norm on an empty tensor because the operation does not have an identity + if (ord in [float('inf'), float('-inf')] or ord < 0) and intersperse_empty_tensors: + continue + + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype) + + # Also test nan propagation with a single tensor, but skip autograd testing + if not requires_grad: + nan_inputs = [ + [float('nan')], + [float('nan'), 1.0], + [1.0, float('nan')], + [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0], + [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0], + [3.0, float('nan'), float('nan'), -1.5, 6.0], + ] + for input in nan_inputs: + x = torch.tensor(input, device=device) + disable_fastpath = True + if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + disable_fastpath = False + yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) + + +class foreach_pointwise_sample_func(foreach_inputs_sample_func): + + def __init__( + self, + arity: int = 3, + rightmost_supports_scalar: bool = False, + rightmost_supports_scalarlist: bool = False, + ): + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist) + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return dtype in integral_types_and(torch.bool) and opinfo.ref == torch.addcmul + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + assert "num_input_tensors" not in kwargs + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + # zero_size tensor + input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) + for _ in range(2) + ] + kwargs.pop("scalars", None) + kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)) + yield ForeachSampleInput(input, *args, **kwargs) + + def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): + num_input_tensors_specified = "num_input_tensors" in kwargs + num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors + assert isinstance(num_input_tensors, list) + _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} + _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) + + for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( + num_input_tensors, self._rightmost_arg_types, (True, False)): + _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors + input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + args = [ + sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) + ] + rightmost_arg_list = self._sample_rightmost_arg( + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + zero_size=False, + allow_higher_dtype_scalars=False if intersperse_empty_tensors else allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ) + for rightmost_arg in rightmost_arg_list: + kwargs = {} + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.append(rightmost_arg) + elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]: + kwargs["scalars"] = rightmost_arg + else: + kwargs["value"] = rightmost_arg + kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) + assert len(args) == 2, f"{len(args)=}" + sample = ForeachSampleInput(input, *args, **kwargs) + yield sample + if rightmost_arg_type == ForeachRightmostArgType.TensorList: + args.pop() + + +foreach_unary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + 'exp', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'acos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'asin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'atan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cos', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'cosh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log10', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'log2', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'tan', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + { + torch.complex64: tol(atol=3e-04, rtol=2e-05) + } + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'tanh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + backward_requires_result=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + toleranceOverride( + {torch.complex64: tol(atol=5e-03, rtol=1e-04)} + ), + 'TestForeach', + 'test_parity', + device_type='cuda' + ), + ), + ), + ForeachFuncInfo( + 'sin', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sinh', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'neg', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_unary_op_tensors_on_different_devices", + device_type="cuda", + dtypes=(torch.bool,), + ), + ), + ), + ForeachFuncInfo( + 'sqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'rsqrt', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'ceil', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erf', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'erfc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'expm1', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'floor', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'log1p', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'round', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'frac', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool) + complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'reciprocal', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'sigmoid', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + backward_requires_result=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + ), + ), + ForeachFuncInfo( + 'trunc', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'abs', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + device_type="cpu", + dtypes=(torch.bool,), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()), + ), + ), + ForeachFuncInfo( + 'zero', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_out=False, + ), + ForeachFuncInfo( + 'sign', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), + ForeachFuncInfo( + 'lgamma', + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)), + # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + # "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", + "test_meta_inplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=complex_types() + integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=complex_types(), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + ), + ), +] + +foreach_binary_op_db: list[OpInfo] = [ + ForeachFuncInfo( + "add", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # These tests fail with aten._local_scalar_dense not being implemented. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), + ), + ), + ForeachFuncInfo( + "sub", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_alpha_param=True, + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "mul", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), + ), + ), + ForeachFuncInfo( + "div", + sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "clamp_min", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "clamp_max", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "minimum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=False, + supports_forward_ad=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + # note(crcrpar): forward ad not implemented. + ForeachFuncInfo( + "maximum", + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_forward_ad=False, + supports_inplace_autograd=False, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + device_type="cuda", + dtypes=(torch.complex128,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_scalar_with_overlapping_tensors", + dtypes=complex_types(), + ), + ), + ), + ForeachFuncInfo( + "pow", + supports_alpha_param=False, + supports_scalar_self_arg=True, + sample_inputs_func=foreach_inputs_sample_func(2, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8, torch.bool), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,),), + DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), + DecorateInfo( + unittest.skip("failed starting on ROCm 6.2"), + "TestForeach", + "test_parity", + device_type="cuda", + dtypes=(torch.complex64,), + active_if=TEST_WITH_ROCM), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_binary_op_with_scalar_self_support", + device_type="cuda", + dtypes=(torch.bool,), + active_if=lambda kwargs: kwargs["is_fastpath"], + ), + ), + backward_requires_result=True, + ), + ForeachFuncInfo( + "copy", + sample_inputs_func=foreach_inputs_sample_func(2, False, False), + supports_out=False, + supports_forward_ad=False, + supports_autograd=False, + supports_inplace_autograd=False, + ) +] + +foreach_pointwise_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "addcmul", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.bool,)), + # # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), + ForeachFuncInfo( + "addcdiv", + sample_inputs_func=foreach_pointwise_sample_func(4, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + # fails with div_cpu is not implemented with ComplexHalf + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=integral_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types() + complex_types_and(torch.bool)), + ), + ), +] + +foreach_reduce_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "max", + sample_inputs_func=foreach_max_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # no complex support for ordering ops like max + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + ), + ), + ForeachFuncInfo( + "norm", + sample_inputs_func=foreach_norm_sample_func(1, False, False), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + device_type="cuda", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +foreach_other_op_db: list[ForeachFuncInfo] = [ + ForeachFuncInfo( + "lerp", + sample_inputs_func=foreach_inputs_sample_func(3, True, True), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + ), + ), +] + +def reference_sign(x): + if x.dtype == np.bool_: + # `np.sign` doesn't support `bool`. + # >>> np.sign(True) + # ufunc 'sign' did not contain a loop + # with signature matching types dtype('bool') -> dtype('bool') + return np.sign(x, dtype=np.uint8).astype(np.bool_) + return np.sign(x) + + +def reference_sgn(x): + # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex. + # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. + # while `torch.sgn` returns, 0 if abs(input) == 0 else input/abs(input) + if x.dtype not in [np.complex64, np.complex128]: + return reference_sign(x) + + out = (x / np.abs(x)) + if out.ndim == 0: + # Handle x == 0 case + if (x == 0): + # Can't assign to np.complex object + # So make a new one. + return np.array(complex(0, 0), dtype=x.dtype) + return out + + # Handle x == 0 case + mask = (x == 0) + out[mask] = complex(0, 0) + return out + + +def reference_sigmoid(x): + # 'scipy.special.expit' not supported for the input types + if x.dtype in [np.complex64, np.complex128]: + return (1 / (1 + np.exp(-x))) + return scipy.special.expit(x) + + +def reference_logsigmoid(x): + return np.where( + x < 0, + x - np.log1p(np.exp(x)), + -np.log1p(np.exp(-x))) + + +def reference_hardsigmoid(x): + intermediate = x / 6 + 0.5 + y = np.clip(intermediate, 0, None) + return np.where(y > 1, 1, y).astype(x.dtype) + + +def reference_lgamma(x): + # scipy.special.gammaln returns `-inf` when input is `-inf`. + # While Pytorch, C and C++, all return `inf` when input is `-inf`. + # Reference: + # https://en.cppreference.com/w/cpp/numeric/math/lgamma + # https://en.cppreference.com/w/c/numeric/math/lgamma + + # To handle the above discrepancy, + # we replace -inf with inf so values + # that were originally -inf map to inf as expected + if x.dtype.kind == 'f': + x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x) + + out = scipy.special.gammaln(x) + + if x.dtype == np.float16: + # `scipy.special.gammaln` returns output of float32 when input is float16, + # while `torch.lgamma` preserves `float16`. But due to smaller range of float16, + # Pytorch version outputs `inf` while SciPy returns finite values. + out = out.astype(np.float16) + + return out + + +def reference_mvlgamma(x, d): + if x.dtype == np.float16: + return scipy.special.multigammaln(x, d).astype(np.float16) + + return scipy.special.multigammaln(x, d) + +def reference_softplus(input, beta=1, threshold=20): + non_linear = input * beta <= threshold + output = input.copy() + output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta + return output + +def reference_gelu(X, *, approximate='none'): + def _gelu_ref(X): + return X * stats.norm.cdf(X) + + def _tanh_gelu_ref(X): + M_SQRT_2_PI = math.sqrt(2 / math.pi) + Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0)) + return 0.5 * X * (1.0 + np.tanh(Z)) + + if approximate == 'tanh': + return _tanh_gelu_ref(X) + else: + return _gelu_ref(X) + + +def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray: + if num_classes == -1: + num_classes = int(np.amax(a) + 1) + + idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes + one_hot = np.zeros((a.size, num_classes), dtype=a.dtype) + np.put(one_hot, idcs, 1) + return one_hot.reshape(*a.shape, -1) + + +def reference_mse_loss(input, target, reduction="mean"): + se = (input - target) ** 2 + if reduction == "mean": + return np.mean(se) + elif reduction == "sum": + return np.sum(se) + else: # reduction == "none" + return se + + +def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, bias=None, eps=1e-5): + return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] + + +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight, bias, eps): + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + if weight is None and bias is not None: + Y = Y + bias.reshape(-1) + elif weight is not None and bias is None: + Y = Y * weight.reshape(-1) + elif weight is not None and bias is not None: + Y = Y * weight.reshape(-1) + bias.reshape(-1) + axis = inp.ndim - len(normalized_shape) + stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape) + return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) + + +def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, eps=None): + if eps is None: + eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps) + Y = inp_view / rms + if weight is not None: + Y = Y * weight.reshape(-1) + return Y.reshape(*inp.shape) + + +def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5): + inp_view = inp + if np.prod(inp.shape) != 0: + inp_view = inp.reshape((inp.shape[0], num_groups, -1)) + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + Y = Y.reshape(inp.shape) + if weight is not None: + # weight is a vector of length equal to the channel + if len(Y.shape) > 2: + weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y * weight + if bias is not None: + # bias is a vector of length equal to the channel + if len(Y.shape) > 2: + bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) + Y = Y + bias + return Y + + +# using a custom reference function since numpy only has a string side arg (instead of right and side) and doesn't +# have an out_int32 arg. Additionally, numpy doesn't support searchsorted with ND arrays, so this splits those into +# stacked 1D cases +def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None): + side = 'right' if (right or side == 'right') else 'left' + if len(sorted_sequence.shape) == 1 : + ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter) + return ret.astype(np.int32) if out_int32 else ret + elif sorted_sequence.shape[0] == 0: + if sorter is not None: + sorter = sorter.flatten() + ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter) + ret = ret.astype(np.int32) if out_int32 else ret + return ret.reshape(boundary.shape) + else: + # numpy searchsorted only supports 1D inputs so we split up ND inputs + orig_shape = boundary.shape + num_splits = np.prod(sorted_sequence.shape[:-1]) + splits = range(num_splits) + sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1) + if sorter is not None: + sorter = sorter.reshape(num_splits, -1) + + split_sequence = [sorted_sequence[i] for i in splits] + split_boundary = [boundary[i] for i in splits] + split_sorter = [sorter[i] if (sorter is not None) else None for i in splits] + + split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort) + for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter, strict=True)] + split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret + return np.stack(split_ret).reshape(orig_shape) + +def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0): + assert mode == 0, "Only mode=0 (xor_sum) is supported right now" + + dtype = tensor.dtype + if dtype.kind == 'f': + tensor = tensor.astype(np.float64).view(np.uint64) + else: + tensor = tensor.astype(np.uint64) + + + if dim == (): + result = np.bitwise_xor.reduce(tensor.flatten(), keepdims=keepdim) + else: + if isinstance(dim, list): + dim = tuple(dim) + result = np.bitwise_xor.reduce(tensor, axis=dim, keepdims=keepdim) + + return result + + +def loss_reference_reduction_wrapper(fn): + def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs): + if size_average is not None or reduce is not None: + raise RuntimeError( + "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper" + ) + output = fn(input, target, **other_kwargs) + if reduction == "mean": + return np.mean(output) + elif reduction == "sum": + return np.sum(output) + else: # reduction == "none" + return output + + return wrapper + +@loss_reference_reduction_wrapper +def reference_smooth_l1_loss(input, target, beta=1.0): + diff = input - target + abs_diff = np.abs(diff) + above_threshold = abs_diff >= beta + + loss = np.empty_like(input) + loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta + loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta) + + return loss + +def reference_std_var(f): + """Forwards unbiased/correction kwargs as NumPy's equivalent ddof""" + g = reference_reduction_numpy(f) + + @wraps(g) + def wrapper(x: npt.NDArray, *args, **kwargs): + assert not ('unbiased' in kwargs and 'correction' in kwargs) + + if 'unbiased' in kwargs: + kwargs['ddof'] = int(kwargs.pop('unbiased')) + elif 'correction' in kwargs: + kwargs['ddof'] = kwargs.pop('correction') + + return g(x, *args, **kwargs) + + return wrapper + +def generate_std_var_kwargs(t: torch.Tensor, **kwargs): + """Generates unbiased/correction kwargs for std/var operators""" + yield ((), {'unbiased': True}) + yield ((), {'unbiased': False}) + + # Currently, calling std with correction is only enabled when + # both dim and keepdim are provided. + if 'dim' in kwargs and 'keepdim' in kwargs: + yield ((), {'correction': 0}) + yield ((), {'correction': 1}) + + numel = torch.tensor(t.shape)[kwargs.get('dim')].prod() + yield ((), {'correction': numel // 2}) + +def error_inputs_mean(op_info, device, is_ref=False, **kwargs): + if is_ref: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg1 = (r"mean\(\): could not infer output dtype. " + r"Input dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []), + error_regex=err_msg1, + ) + + if is_ref: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: torch.int64") + else: + err_msg2 = (r"mean\(\): could not infer output dtype. " + r"Optional dtype must be either a floating point or complex dtype. " + r"Got: Long") + yield ErrorInput( + SampleInput( + make_tensor((3, 4, 5), dtype=torch.float32, device=device), + [], + dtype=torch.int64), + error_regex=err_msg2 + ) + +# numpy implementation of torch.flatten +# unfortunately there's no np.flatten. we figure out the desired shape and call np.reshape +def reference_flatten(input, start_dim=0, end_dim=-1): + in_shape = input.shape + in_rank = len(in_shape) + for d in start_dim, end_dim: + if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}") + end_dim = end_dim if end_dim >= 0 else in_rank + end_dim + start_dim = start_dim if start_dim >= 0 else in_rank + start_dim + if in_rank == 0: + end_dim = start_dim + if end_dim < start_dim: + raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim") + flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1) + out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] + return np.reshape(input, out_shape) + + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + +# Operator database (sorted alphabetically) +op_db: list[OpInfo] = [ + UnaryUfuncInfo('abs', + aliases=('absolute', ), + ref=np.abs, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_grad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', + 'test_inplace_gradgrad', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients', + 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), + DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs", + "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input) + # We can break the logic of the loop over all possible types but it is OK. + # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides', + dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), + ), + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True), + # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) + UnaryUfuncInfo('acos', + aliases=('arccos', ), + ref=np.arccos, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-1, + torch.complex64: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD', + dtypes=[torch.cdouble], active_if=IS_WINDOWS),)), + # NOTE: the derivative for inplace acosh is not implemented + UnaryUfuncInfo('acosh', + aliases=('arccosh', ), + ref=np.arccosh, + domain=(1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + # acosh is not defined at x < 1 (real) + reference_numerics_filter=NumericsFilter( + condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)), + safe_val=2)), + BinaryUfuncInfo('add', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: ( + np.add(input, other) + if alpha == 1 + else np.add(input, np.multiply(alpha, other)) + ), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_add_sub, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestCommon', + 'test_numpy_refs', + dtypes=(torch.complex128,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('item', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs), + ref=np.ndarray.item, + method_variant=None, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool), + dtypesIfHpu=custom_types(torch.float32), + supports_out=False, + supports_autograd=False, + error_inputs_func=error_inputs_item, + sample_inputs_func=sample_inputs_item, + skips=( + # Error testing item function variant + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32, torch.complex64)), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: Composite compliance check failed with the above error. + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'), + # Booleans mismatch: AssertionError: False is not true + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'), + )), + OpInfo('arange', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_out=True, + supports_autograd=False, + is_factory_function=True, + error_inputs_func=error_inputs_arange, + sample_inputs_func=sample_inputs_arange, + skips=( + # https://github.com/pytorch/pytorch/issues/81774 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Lazy tensor failures + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + + # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608 + # We don't have an op for aten::arange but it isn't a special case. + # Argument types: bool, bool, bool, int, int, Device, boo + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + + # Captured graph does not contain aten::arange (succeeds on complex!) + # g: graph(): + # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # return (%25) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('cauchy', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.cauchy_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_cauchy, + error_inputs_func=error_inputs_cauchy, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('exponential', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.exponential_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_exponential, + error_inputs_func=error_inputs_exponential, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('geometric', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.geometric_, + dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_geometric, + error_inputs_func=error_inputs_geometric, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('log_normal', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.log_normal_, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_log_normal, + error_inputs_func=error_inputs_log_normal, + skips=( + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('normal', + variant_test_name='in_place', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs), + inplace_variant=torch.Tensor.normal_, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_normal, + error_inputs_func=error_inputs_normal, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # vmap: calling random operator not supported + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + )), + OpInfo('uniform', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs), + method_variant=None, + inplace_variant=torch.Tensor.uniform_, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=False, + is_factory_function=False, + allow_cow_input_materialize_forward=[0], + sample_inputs_func=sample_inputs_uniform, + error_inputs_func=error_inputs_uniform, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # aten.uniform was not decomposed + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('clamp_max', + ref=_clamp_max_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('clamp_min', + ref=_clamp_min_numpy, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_rhs_python_scalar=False, + supports_fwgrad_bwgrad=True, + rhs_make_tensor_kwargs=dict(exclude_zero=False), + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + # dispatch to lazy test failed + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), + )), + BinaryUfuncInfo('mul', + aliases=('multiply',), + dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + error_inputs_sparse_func=error_inputs_sparse_mul, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)), + BinaryUfuncInfo('sub', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), + aliases=('subtract',), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_add_sub, + supports_two_python_scalars=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + )), + OpInfo('addmm', + # This addmm OpInfo is for when alpha and beta are not both equal to 1. + # alpha=beta=1 is tested in the following opinfo, because that special case will + # trigger addmm being decomposed by a jit pass. + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_addmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('addmm', + # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add. + variant_test_name='decomposed', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], + sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1), + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # https://github.com/pytorch/pytorch/issues/71784 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.float16,)), + )), + OpInfo('addmv', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + sample_inputs_func=sample_inputs_addmv), + OpInfo('addbmm', + ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M), + np.multiply(np.asarray(alpha, dtype=batch1.dtype), + np.sum(np.matmul(batch1, batch2), axis=0))), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_refs'), + # MPS has slightly worse precision. Is this acceptable? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04), + torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}), + 'TestConsistency', + 'test_output_match', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}), + 'TestCommon', 'test_out'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ], + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # addbmm does not correctly warn when resizing out= inputs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # https://github.com/pytorch/pytorch/issues/55907 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addbmm), + OpInfo('baddbmm', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, + torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [], + torch.complex64, torch.complex128), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + # Higher differences starting with Zen3 or Alder Lake + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=4e-05, rtol=4e-06)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view', device_type='cuda'), + ], + sample_inputs_func=sample_inputs_baddbmm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('dot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('vdot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_dot_vdot, + error_inputs_func=error_inputs_dot_vdot, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + )), + OpInfo('bmm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), + "TestCommon", "test_out"), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + ), + sample_inputs_func=sample_inputs_bmm), + OpInfo('mv', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mv), + OpInfo('addr', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + # Reference: https://github.com/pytorch/pytorch/issues/50747 + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/50747 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), + ), + sample_inputs_func=sample_inputs_addr, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('addcmul', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + OpInfo('addcdiv', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # TODO: update sample inputs with for_inplace_variant kwarg to support this test + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_variant_consistency_eager'), + ), + sample_inputs_func=sample_inputs_addcmul_addcdiv, + reference_inputs_func=partial( + reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), + UnaryUfuncInfo('asin', + aliases=('arcsin', ), + ref=np.arcsin, + domain=(-1, 1), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + # NOTE: derivative for inplace asinh is not implemented + UnaryUfuncInfo('asinh', + aliases=('arcsinh', ), + ref=np.arcsinh, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('atan', + aliases=('arctan', ), + ref=np.arctan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + BinaryUfuncInfo('atan2', + aliases=('arctan2',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + )), + UnaryUfuncInfo('atanh', + aliases=('arctanh', ), + ref=np.arctanh, + domain=(-1, 1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=[ + precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ], + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('allclose', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=np.allclose, + supports_autograd=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_allclose, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_out=False), + OpInfo('broadcast_to', + ref=np.broadcast_to, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_broadcast_to), + OpInfo('broadcast_shapes', + op=torch.broadcast_shapes, + ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None, + dtypes=_dispatch_dtypes((torch.float32,)), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + supports_autograd=False, + supports_scripting=False, + sample_inputs_func=sample_inputs_broadcast_shapes, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # skip dtype tests since broadcast_shape is not device dependent. + # having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('broadcast_tensors', + ref=np.broadcast_arrays, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_broadcast_tensors, + reference_inputs_func=reference_inputs_broadcast_tensors, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + )), + OpInfo('block_diag', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Default batching rule in core doesn't work for ops with TensorList args + check_batched_forward_grad=False, + skips=( + # https://github.com/pytorch/pytorch/issues/64997 + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_block_diag), + UnaryUfuncInfo('bitwise_not', + ref=np.bitwise_not, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.invert, + supports_autograd=False), + BinaryUfuncInfo('bitwise_left_shift', + op=torch.bitwise_left_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.lshift, + inplace_operator_variant=operator.ilshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('bitwise_right_shift', + op=torch.bitwise_right_shift, + dtypes=integral_types(), + dtypesIfCUDA=integral_types(), + dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool), + operator_variant=operator.rshift, + inplace_operator_variant=operator.irshift, + supports_autograd=False, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs=dict(low=0), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('combinations', + op=torch.combinations, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_combinations), + OpInfo('cartesian_prod', + op=torch.cartesian_prod, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_cartesian_prod, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, + 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('cdist', + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_cdist), + UnaryUfuncInfo('ceil', + ref=np.ceil, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('cholesky', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_cholesky, + gradcheck_wrapper=gradcheck_wrapper_hermitian_input, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],), + OpInfo('cholesky_inverse', + dtypes=floating_and_complex_types(), + backward_dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + check_batched_gradgrad=True, + sample_inputs_func=sample_inputs_linalg_cholesky_inverse, + gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal, + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestCommon', device_type='cpu', + ), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=5e-03, rtol=1e-04) + }), + 'TestEagerFusionOpInfo', device_type='cpu', + ), + ], + skips=( + # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),), + ), + OpInfo('cholesky_solve', + op=torch.cholesky_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_cholesky_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + OpInfo('chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_chunk, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('unsafe_chunk', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_chunk, + check_batched_forward_grad=False, + reference_inputs_func=reference_inputs_chunk, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('clone', + ref=np.copy, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format' + # (NumPy reference needs to be extended with memory_format) + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + ),), + OpInfo('contiguous', + op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_clone_contiguous, + reference_inputs_func=reference_inputs_clone_contiguous, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_fusible_nodes=['aten::contiguous'], + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('sum_to_size', + op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sum_to_size, + error_inputs_func=error_inputs_sum_to_size, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)), + )), + OpInfo('clamp', + aliases=('clip',), + ref=_clamp_numpy, + dtypes=all_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_clamp, + reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # NNC appear to not handle boolean clamp + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + # MPS does not support float64, while numpy does internal computations in float64. + # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264 + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('positive', + ref=np.positive, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + ), + UnaryUfuncInfo('conj', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.int32), + supports_sparse=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False), + UnaryUfuncInfo('conj_physical', + decomp_aten_name='_conj_physical', + ref=np.conj, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, + torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # RuntimeError: inputSet && outputSet + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )), + DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"), + 'TestSparseUnaryUfuncs', 'test_inplace'), + )), + OpInfo('resolve_conj', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('resolve_neg', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_view_as_real, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + ), + OpInfo('view_as_real', + dtypes=complex_types(), + supports_forward_ad=True, + supports_out=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_view_as_real, + test_conjugated_samples=False, + ), + OpInfo('view_as_complex', + dtypes=floating_types_and(torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + test_neg_view=False, + sample_inputs_func=sample_inputs_view_as_complex, + skips=( + # RuntimeError: Tensor must have a last dimension with stride 1 + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + BinaryUfuncInfo('complex', + dtypes=floating_types_and(torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + error_inputs_func=error_inputs_complex, + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)), + BinaryUfuncInfo('copysign', + sample_inputs_func=sample_inputs_copysign, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('corrcoef', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_corrcoef, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False), + UnaryUfuncInfo('cos', + ref=np.cos, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('cosh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + )), + OpInfo('cov', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_cov, + error_inputs_func=error_inputs_cov, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Float did not match double + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Jacobian mismatch + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507) + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950 + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cpu"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + )), + OpInfo('cross', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_cross, + supports_fwgrad_bwgrad=True, + supports_out=True, + supports_forward_ad=True), + OpInfo('cumsum', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumsum does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + sample_inputs_func=sample_inputs_cumulative_ops), + OpInfo('cumprod', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # cumprod does not handle correctly out= dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + # gradgradcheck fails in fast_mode=True: #56275 + sample_inputs_func=sample_inputs_cumprod, + gradcheck_fast_mode=False), + OpInfo('cummax', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('cummin', + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + UnaryUfuncInfo('deg2rad', + ref=np.radians, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + OpInfo('diff', + op=torch.diff, + # np.diff has np._NoValue as default values for prepend and append, compare_with_reference breaks if prepend/append + # are set as None when converting to numpy + ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: ( + np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append) + ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diff, + error_inputs_func=error_inputs_diff, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='no_rounding_mode', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True),), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='trunc_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + )), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='floor_rounding', + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + sample_kwargs=lambda device, dtype, input: + ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + skips=( + # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(-17746.9307, dtype=torch.float64) + # analytical:tensor(0., dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + BinaryUfuncInfo('true_divide', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + promotes_int_to_float=True, + supports_fwgrad_bwgrad=True, + supports_two_python_scalars=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + OpInfo('equal', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + ref=lambda input, other: (input == other).all(), + sample_inputs_func=sample_inputs_equal, + supports_autograd=False, + supports_tracing=False, + skips=( + )), + UnaryUfuncInfo('exp', + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + OpInfo('expand', + op=lambda self, shape: self.expand(shape), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('expand_as', + op=lambda self, other: self.expand_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_expand_as, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),), + ), + OpInfo('expand_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_expand, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + supports_out=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + OpInfo('diag', + ref=np.diag, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_diag, + error_inputs_func=error_inputs_diag), + OpInfo('diag_embed', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal', + aten_backward_name='diagonal_backward', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_diag_embed, + reference_inputs_func=reference_inputs_diagonal_diag_embed, + error_inputs_func=error_inputs_diagonal_diag_embed), + OpInfo('diagonal_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + BinaryUfuncInfo('eq', + ref=np.equal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + always_returns_bool=True, + supports_autograd=False, + sample_inputs_func=sample_inputs_comparison_ops, + skips=( + )), + BinaryUfuncInfo('fmax', + op=torch.fmax, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmin', + op=torch.fmin, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + )), + BinaryUfuncInfo('fmod', + ref=np.fmod, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(101.6283, dtype=torch.float64) + # analytical:tensor(-18.3575, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', + dtypes=(torch.float64,), + device_type='cpu'), + )), + BinaryUfuncInfo('remainder', + ref=np.remainder, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=None, + operator_variant=operator.mod, + inplace_operator_variant=operator.imod, + supports_one_python_scalar=True, + rhs_make_tensor_kwargs={'exclude_zero': True}, + decorators=( + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + # Fails on XLA + # False is not true : Tensors failed to compare as equal! + # Attempted to compare equality of tensors with different dtypes + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # FIXME: + # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for + # output 0 with respect to input 1, + # numerical:tensor(102.4676, dtype=torch.float64) + # analytical:tensor(-17.5182, dtype=torch.float64) + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', + 'test_fn_grad', device_type='cpu', + dtypes=(torch.float64,)), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=5e-4, rtol=3e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo(unittest.skip("Broken on MacOS13"), + 'TestConsistency', + 'test_output_match', + device_type='mps', + dtypes=(torch.float16,), + active_if=lambda _: MACOS_VERSION < 14.0), + )), + UnaryUfuncInfo('frac', + ref=lambda x: np.modf(x)[0], + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + # 76047 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.bfloat16, torch.float32, torch.float64)), + )), + OpInfo('stft', + decorators=[ + skipCPUIfNoFFT, + DecorateInfo(unittest.skip("Skipped! stft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ], + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_stft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + ), + OpInfo('istft', + dtypes=complex_types(), + sample_inputs_func=sample_inputs_istft, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_out=False, + decorators=( + DecorateInfo(unittest.skip("Skipped! istft does not match the native function"), + 'TestJit', 'test_variant_consistency_jit'), + ), + skips=( + skipCPUIfNoFFT, + # gradcheck fails on ROCm (gh-68429) + # grad is computed improperly (probably for weights tensor) + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + )), + UnaryUfuncInfo('floor', + ref=np.floor, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('flip', + op=torch.flip, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + sample_inputs_func=sample_inputs_flip, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('fliplr', + op=torch.fliplr, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_fliplr, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('flipud', + op=torch.flipud, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fliplr_flipud, + error_inputs_func=error_inputs_flipud, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('sparse.sampled_addmm', + dtypes=floating_and_complex_types(), + supports_autograd=True, + sample_inputs_func=sample_inputs_sparse_sampled_addmm, + decorators=[ + skipCPUIfNoMklSparse, + skipXPU], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype. + # RuntimeError: Sparse CSR tensors do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend. + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + )), + OpInfo('sparse.mm', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + variant_test_name='reduce', + supports_autograd=True, + supports_out=False, + supports_gradgrad=False, + supports_forward_ad=False, + sample_inputs_func=sample_inputs_sparse_mm_reduce, + decorators=[onlyCPU], + skips=( + # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # RuntimeError: Sparse CSR tensors do not have strides. + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: unsupported memory format option Preserve + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # RuntimeError: Sparse CSR tensors do not have is_contiguou + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + # RuntimeError: Sparse CSR tensors do not have strides + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'), + # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'), + )), + UnaryUfuncInfo('i0', + ref=np_unary_ufunc_integer_promotion_wrapper( + scipy.special.i0) if TEST_SCIPY else None, + aliases=('special.i0',), + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_i0_i1, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.int8,)), + )), + BinaryUfuncInfo('floor_divide', + ref=_floor_divide_np, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + + supports_autograd=False, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + supports_two_python_scalars=True, + skips=( + # AssertionError: Results of original model and exported/imported version of model differed + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + )), + UnaryUfuncInfo('frexp', + op=torch.frexp, + ref=np.frexp, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + # skip testing torch.frexp as it is not supported by ROCm platform yet + decorators=[], + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # skips below tests as torch.frexp returns tuple-like (mantissa, exponent) as outputs, + # while these tests currently requires output to a single tensor. + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + + # skips test_reference_numerics due to error in Windows CI. + # The np.frexp returns exponent as np.intc dtype on Windows platform, + # and np.intc does not have the correspond torch dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('log1p', + ref=np.log1p, + aliases=('special.log1p',), + domain=(-1, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + promotes_int_to_float=True), + BinaryUfuncInfo('ge', + ref=np.greater_equal, + aliases=('greater_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('geqrf', + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + supports_autograd=False, + skips=( + # FIXME: geqrf can't forward with complex inputs that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + BinaryUfuncInfo('gt', + ref=np.greater, + aliases=('greater',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + UnaryUfuncInfo('imag', + ref=np.imag, + dtypes=complex_types_and(torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo('gradient', + dtypes=floating_and_complex_types_and(torch.int8, torch.int16, + torch.int32, torch.int64, + torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # following tests give a runtime error with undefined value tensor + # see discussion : https://github.com/pytorch/pytorch/issues/56660 + # RuntimeError: + # Arguments for call are not valid. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950 + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + ), + supports_inplace_autograd=False, + sample_inputs_func=sample_inputs_gradient, + error_inputs_func=error_inputs_gradient), + OpInfo('isin', + dtypes=all_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + sample_inputs_func=sample_inputs_isin), + OpInfo('kthvalue', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kthvalue, + error_inputs_func=error_inputs_kthvalue), + BinaryUfuncInfo('le', + ref=np.less_equal, + aliases=('less_equal',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('linspace', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_linspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + OpInfo('logspace', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + is_factory_function=True, + supports_out=True, + supports_autograd=False, + error_inputs_func=error_inputs_linspace, + sample_inputs_func=sample_inputs_logspace_tensor_overload, + variant_test_name="tensor_overload", + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', + dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), + # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API + # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! + # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. + # CUDA driver allocated memory was 1254555648 and is now 1242955776. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.cfloat,), device_type="cuda"), + )), + UnaryUfuncInfo('log', + ref=np.log, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log10', + ref=np.log10, + domain=(0, None), + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + # log10(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + UnaryUfuncInfo('log2', + ref=np.log2, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + # log2(z)->-inf for |z|->0 + reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), + BinaryUfuncInfo('ldexp', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + promotes_int_to_float=True, + supports_out=True, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: mul(): functions with out=... arguments don't support + # automatic differentiation, but one of the arguments requires grad + # https://github.com/pytorch/pytorch/issues/68966 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.complex64: tol(atol=1e-05, rtol=1e-05) + }), + 'TestCommon', device_type='cpu', + ), + ], ), + BinaryUfuncInfo('logaddexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False), + OpInfo('logaddexp2', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_logaddexp), + UnaryUfuncInfo('logical_not', + ref=np.logical_not, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + skips=( + # The function variant always returns BoolTensor + # while the inplace variant preserves the input dtype. + # >>> t = torch.randn(3) + # >>> torch.logical_not(t) + # tensor([False, False, False]) + # >>> torch.logical_not(t).dtype + # torch.bool + # >>> t.logical_not_().dtype + # torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), + )), + BinaryUfuncInfo('lt', + ref=np.less, + aliases=('less',), + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.int32), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('lu_unpack', + op=torch.lu_unpack, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=(skipCPUIfNoLapack,), + sample_inputs_func=sample_inputs_lu_unpack), + OpInfo('lu', + op=torch.lu, + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_lu, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + # we skip jit tests because `lu` is a torch function + # RuntimeError: + # 'Tensor (inferred)' object has no attribute or method 'lu'.: + # File "", line 3 + # def the_method(i0): + # return i0.lu(True, True) + # ~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('lu_solve', + op=torch.lu_solve, + dtypes=floating_and_complex_types(), + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_lu_solve, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Tests different backward paths"), + "TestCommon", "test_floating_inputs_are_differentiable"),), + decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]), + OpInfo('masked_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_fill, + error_inputs_func=error_inputs_masked_fill, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + supports_out=False), + OpInfo('masked_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32), + sample_inputs_func=sample_inputs_masked_scatter, + error_inputs_func=error_inputs_masked_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('masked_select', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_masked_select, + error_inputs_func=error_inputs_masked_select, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('matrix_exp', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + aliases=('linalg.matrix_exp',), + sample_inputs_func=sample_inputs_matrix_exp, + # Needs to construct a 2nx2n matrix by copy_ ing into it + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # mexp does not support bf16 and fp16 + DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive', + dtypes=[torch.half], device_type="cpu"), + ), + supports_out=False, + ), + OpInfo('matmul', + aliases=('linalg.matmul',), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False), + decorators=[ + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + # ROCm intermittently fails the test with standard atol/rtol + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda', + active_if=TEST_WITH_ROCM), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), + 'TestCommon', 'test_out', device_type='cuda', + active_if=TEST_WITH_ROCM), + # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the + # backward on CPU + DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu'), + DecorateInfo( + toleranceOverride({ + torch.float32: tol(atol=1e-5, rtol=1e-5), + torch.complex64: tol(atol=1e-5, rtol=1e-5), + }), + "TestDecomp", "test_comprehensive", device_type="cuda", + ), + ], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', + device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + OpInfo('max', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + skips=( + ), + supports_forward_ad=True), + OpInfo('max', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('median', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + # TODO: some signatures of median do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_median, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('nanmedian', + dtypes=all_types_and(torch.bfloat16, torch.float16), + # TODO: some signatures of nanmedian do support out + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), + OpInfo('var_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('var_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + # TODO: some signatures of std_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), + "TestDecomp", "test_comprehensive", device_type="cuda"), + )), + OpInfo('std_mean', + variant_test_name='unbiased', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + # TODO: some signatures of var_mean do support out + supports_out=False, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestDecomp", + "test_comprehensive", + device_type="cuda" + ), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=4e-5, rtol=9e-3), + torch.float64: tol(atol=2e-7, rtol=2e-7), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + )), + OpInfo('meshgrid', + variant_test_name='variadic_tensors', + ref=np.meshgrid, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'), + skips=[ + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ], + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('meshgrid', + variant_test_name='list_of_tensors', + # Unlike the variant above, we do not use np.meshgrid as a + # ref since it does not officially support list of numpy + # arrays. + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'), + skips=[ + # meshgrid is defined in torch.functional to take a + # variadic list of tensors. Variadic parameters are not + # compatible with the normalize operator tests. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + ], + assert_autodiffed=True, + supports_out=False, + autodiff_nonfusible_nodes=[], + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False,), + OpInfo('min', + variant_test_name='reduction_with_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + sample_inputs_func=sample_inputs_max_min_reduction_with_dim, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + skips=( + )), + OpInfo('min', + variant_test_name='reduction_no_dim', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_max_min_reduction_no_dim, + skips=( + )), + OpInfo('quantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + OpInfo('nanquantile', + dtypes=floating_types(), + sample_inputs_func=sample_inputs_reduction_quantile, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which + # does not have a batching rule in core + check_batched_forward_grad=False), + BinaryUfuncInfo( + 'max', + aliases=('maximum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'maximum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.maximum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo( + 'min', + aliases=('minimum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # Incorrectly attempts to use a scalar for the second argument + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo( + 'minimum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ref=np.minimum, + supports_rhs_python_scalar=False, + skips=( + # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + ), + ), + BinaryUfuncInfo('logical_and', + ref=np.logical_and, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_or', + ref=np.logical_or, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('logical_xor', + ref=np.logical_xor, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool), + supports_autograd=False, + always_returns_bool=True, + supports_rhs_python_scalar=False, + skips=( + )), + BinaryUfuncInfo('bitwise_and', + ref=np.bitwise_and, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.and_, + inplace_operator_variant=operator.iand, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # RuntimeError: "bitwise_and_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_type_promotion', device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_or', + ref=np.bitwise_or, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.or_, + inplace_operator_variant=operator.ior, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('bitwise_xor', + ref=np.bitwise_xor, + dtypes=integral_types_and(torch.bool), + dtypesIfHpu=custom_types(torch.bool), + operator_variant=operator.xor, + inplace_operator_variant=operator.ixor, + supports_autograd=False, + supports_one_python_scalar=True, + skips=( + # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half' + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion', + device_type='cuda'), + )), + BinaryUfuncInfo('heaviside', + ref=lambda a, b: ( + # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64 + np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b) + ), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: heaviside is not yet implemented for tensors with different dtypes. + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + BinaryUfuncInfo('lcm', + ref=np.lcm, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False), + BinaryUfuncInfo('gcd', + ref=np.gcd, + dtypes=integral_types_and(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)),)), + BinaryUfuncInfo('isclose', + ref=np.isclose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_isclose, + error_inputs_func=error_inputs_isclose, + supports_autograd=False, + supports_out=False, + supports_rhs_python_scalar=False, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_numpy_refs', dtypes=(torch.complex128,)), + # RuntimeError: Short did not match Int + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + )), + # `softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + aten_name='softmax', + aten_backward_name='_softmax_backward_data', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo('softmax', + aliases=('special.softmax', 'nn.functional.softmax',), + variant_test_name="with_dtype", + aten_name='softmax', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + OpInfo( + '_softmax_backward_data', + op=torch.ops.aten._softmax_backward_data, + aten_name='_softmax_backward_data', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_softmax_backward_data, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + ), + # `softmin` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + # https://github.com/pytorch/pytorch/issues/68752 + OpInfo('nn.functional.softmin', + aten_name='softmin', + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + assert_jit_shape_analysis=False, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.softmin', + variant_test_name="with_dtype", + aten_name='softmin', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo( + "nn.functional.cross_entropy", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_cross_entropy, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536 + # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked + # 1536 bytes CUDA memory on device 0 + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + DecorateInfo(unittest.skip("FP16 corss_entropy cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ) + ), + OpInfo('nn.functional.normalize', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_normalize, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo('aminmax', + ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), + decorators=(onlyNativeDeviceTypes,), + supports_autograd=False, + sample_inputs_func=sample_inputs_aminmax, + error_inputs_func=error_inputs_aminmax_amax_amin), + OpInfo('as_strided', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + )), + OpInfo('as_strided', + variant_test_name='partial_views', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_partial_views, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # These fail because the test changes the input's in-memory layout + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'), + DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # Fail but are also flaky + DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'), + DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon', + 'test_non_standard_bool_values'), + # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a + # storage size of 28 are out of bounds for storage of size 20 + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides'), + )), + OpInfo('as_strided_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided, + skips=( + # Note: This xfail is fine -- it's inherent to how as_strided works + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), + # AssertionError: False is not true : Scalars failed to compare as equal! + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_variant_consistency_eager'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), + 'TestCommon', 'test_complex_half_reference_testing'), + # Not close + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + )), + OpInfo('as_strided_scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_as_strided_scatter, + error_inputs_func=error_inputs_as_strided_scatter, + skips=( + DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950 + DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950 + DecorateInfo(unittest.skip('Fails on cuda'), 'TestCommon', 'test_complex_half_reference_testing', + active_if=not TEST_WITH_ROCM), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + # AssertionError: Tensor-likes are not close! (new_empty_strided.default) + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)), + OpInfo('native_layer_norm', + aten_name='native_layer_norm', + ref=reference_native_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + assert_jit_shape_analysis=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_native_layer_norm, + error_inputs_func=error_inputs_native_layer_norm, + skips=( + # IndexError: tuple index out of range + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'), + # Tests fail when weight=None and bias is defined + # https://github.com/pytorch/pytorch/issues/79705 + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + # JIT test also tries to compute double backward, which fails + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}), + "TestDecomp", "test_comprehensive", device_type="cpu"), + )), + OpInfo('native_batch_norm', + aten_name='native_batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs_native_batch_norm, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + # AssertionError: Booleans mismatch: True is not False + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_native_batch_norm_legit', + aten_name='_native_batch_norm_legit', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__native_batch_norm_legit, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # https://github.com/pytorch/pytorch/issues/85960 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + ) + ), + OpInfo('_batch_norm_with_update', + op=torch.ops.aten._batch_norm_with_update, + aten_name='_batch_norm_with_update', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[3, 4], + allow_cow_input_materialize_backward=[3, 4], + sample_inputs_func=sample_inputs__batch_norm_with_update, + skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), + # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), + # Problem with _get_numerical_jacobian + # IndexError: tuple index out of range + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), + "TestCompositeCompliance", "test_forward_ad"), + # _batch_norm_with_update expects contiguous inputs for cudnn and miopen + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"), + DecorateInfo(unittest.expectedFailure, + 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"), + # _batch_norm_with_update does not have python bindings + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # aten out variants do not accept out= kwarg, only python out variants + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + ) + ), + OpInfo('nn.functional.cosine_similarity', + aten_name="cosine_similarity", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ], + sample_inputs_func=sample_inputs_cosine_similarity), + OpInfo('nn.functional.adaptive_avg_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool1d, + sample_inputs_func=sample_inputs_adaptive_avg_pool1d), + OpInfo('nn.functional.adaptive_avg_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool2d, + sample_inputs_func=sample_inputs_adaptive_avg_pool2d), + OpInfo('nn.functional.adaptive_avg_pool3d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + decorators=( + # RuntimeError: + # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_avg_pool3d, + sample_inputs_func=sample_inputs_adaptive_avg_pool3d), + OpInfo('nn.functional.adaptive_max_pool1d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool1d, + sample_inputs_func=sample_inputs_adaptive_max_pool1d), + OpInfo('nn.functional.adaptive_max_pool2d', + dtypes=floating_types_and(torch.half, torch.bfloat16), + decorators=( + # RuntimeError: + # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, int]'. : + # File "", line 3 + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool2d, + sample_inputs_func=sample_inputs_adaptive_max_pool2d), + OpInfo('nn.functional.adaptive_max_pool3d', + dtypes=floating_types_and(torch.bfloat16, torch.half), + decorators=( + # RuntimeError: + # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): + # Expected a value of type 'List[int]' for argument 'output_size' but + # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : + # File "", line 3 + # + # def the_method(i0): + # return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None)) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + # + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_adaptive_max_pool3d, + sample_inputs_func=sample_inputs_adaptive_max_pool3d), + OpInfo('nn.functional.avg_pool1d', + aten_name='avg_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool1d, + sample_inputs_func=sample_inputs_avgpool1d), + OpInfo('nn.functional.avg_pool3d', + aten_name='avg_pool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_avg_pool3d, + sample_inputs_func=sample_inputs_avgpool3d, + skips=( + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + )), + OpInfo( + "nn.functional.binary_cross_entropy_with_logits", + aten_name="binary_cross_entropy_with_logits", + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + UnaryUfuncInfo( + 'nn.functional.relu', + aten_name="relu", + ref=lambda a: np.where(a <= 0, 0, a), + supports_autograd=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + dtypes=all_types_and(torch.half, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_nn_activation_relu, + supports_out=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True), + OpInfo('nn.functional.conv_transpose1d', + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d), + aten_name='conv_transpose1d', + aliases=('conv_transpose1d',), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }), + 'TestCommon', 'test_numpy_ref_mps'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64,)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float,)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose2d', + aten_name='conv_transpose2d', + aliases=('conv_transpose2d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose2d, + # Runs very slowly on slow-gradcheck for complex. + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # AssertionError: None mismatch: torch.complex64 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules', + dtypes=(torch.complex64, torch.complex128)), + ), + supports_out=False,), + OpInfo('nn.functional.conv_transpose3d', + aten_name='conv_transpose3d', + aliases=('conv_transpose3d',), + # `ref` for this function is backward of + # corresponding `conv*d` + ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d), + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and( + torch.float16, torch.chalf, torch.bfloat16), + sample_inputs_func=sample_inputs_conv_transpose3d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }), + 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), + torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), + 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }), + 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda', + active_if=TEST_CUDNN), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}), + "TestMathBits", "test_conj_view", device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.int64,)), + # Reference: https://github.com/pytorch/pytorch/issues/86356 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', + dtypes=(torch.double, torch.cdouble)), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=[torch.complex32], active_if=TEST_WITH_ROCM), + ), + supports_out=False,), + OpInfo('nn.functional.conv1d', + aliases=('conv1d',), + aten_name='conv1d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv1d, + error_inputs_func=error_inputs_conv1d, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing' + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv2d', + aliases=('conv2d',), + aten_name='conv2d', + dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, + torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_conv2d), + error_inputs_func=error_inputs_conv2d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'), + # Ref: https://github.com/pytorch/pytorch/issues/75309 + # AssertionError: None mismatch: torch.complex128 is not None + DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', + 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.conv3d', + aliases=('conv3d',), + aten_name='conv3d', + dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_conv3d, + error_inputs_func=error_inputs_conv3d, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), + 'TestCommon', 'test_complex_half_reference_testing', + ), + # TF32 + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), + torch.complex64: tol(atol=5e-3, rtol=1e-3)}), + 'TestCommon', 'test_noncontiguous_samples', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + 'TestCommon', 'test_variant_consistency_eager', + ), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}), + 'TestMathBits', 'test_conj_view', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}), + 'TestOperators', 'test_vjpvmap', + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + skips=( + # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: UNSUPPORTED DTYPE: complex + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), + # AssertionError: Tensor-likes are not close! + # break slow tests + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + ), + supports_expanded_weight=True, + supports_out=False,), + OpInfo('nn.functional.group_norm', + aten_name='group_norm', + aliases=('group_norm',), + ref=reference_group_norm, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_group_norm, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}), + "TestDecomp", + "test_comprehensive", + device_type="cpu" + ), + ], + sample_inputs_func=sample_inputs_group_norm, + reference_inputs_func=reference_inputs_group_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.instance_norm', + # no ref because instance_norm will often have numerical instability (large numbers or nan) + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=['running_mean', 'running_var'], + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_instance_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.layer_norm', + aten_name='layer_norm', + aten_backward_name='layer_norm_backward', + aliases=('layer_norm',), + ref=reference_layer_norm, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), + 'TestCommon', 'test_numpy_refs' + ), + DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'), + ], + sample_inputs_func=sample_inputs_layer_norm, + supports_expanded_weight=True,), + OpInfo('nn.functional.rms_norm', + aten_name='rms_norm', + aliases=('rms_norm',), + ref=reference_rms_norm, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rms_norm, + error_inputs_func=error_inputs_rms_norm,), + OpInfo('nn.functional.local_response_norm', + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ], + sample_inputs_func=sample_inputs_local_response_norm,), + OpInfo('constant_pad_nd', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=sample_inputs_constant_pad_nd, + supports_out=False, + skips=( + # bool can't be passed to Scalar arguments in JIT tracer because + # BoolType is not a subtype of ScalarType. + DecorateInfo( + unittest.expectedFailure, 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('nn.functional.pad', + variant_test_name='constant', + aten_name='constant_pad_nd', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'), + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='reflect', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='replicate_negative', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_nn_pad_replicate_negative, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Some negative padding cases cause a segfault on MPS + DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False), + OpInfo('nn.functional.pad', + variant_test_name='circular', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Doesn't have a corresponding aten operator. + # RuntimeError: falseINTERNAL ASSERT FAILED at + # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + # Difference from is larger with decomposition new_empty_strided.default than original on output 0 + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'), + ), + supports_out=False), + OpInfo('nn.functional.hardswish', + aten_name="hardswish", + aten_backward_name='hardswish_backward', + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardswish, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_nonfusible_nodes=["aten::hardswish"]), + OpInfo('nn.functional.unfold', + aten_name='im2col', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_nn_unfold, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + # NOTE: this failure may not reproduce consistently on different systems + # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185 + DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'), + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest-exact', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='linear', + supports_autograd=True, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'linear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bilinear', + supports_fwgrad_bwgrad=True, + supports_autograd=True, + supports_forward_ad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bicubic', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), + reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='trilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='area', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'area'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.upsample_bilinear', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'), + reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('_upsample_bilinear2d_aa', + op=torch.ops.aten._upsample_bilinear2d_aa, + aten_name='_upsample_bilinear2d_aa', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo( + "nn.functional.soft_margin_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + # doesn't support grad on target + sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False), + error_inputs_func=error_inputs_soft_margin_loss, + ), + OpInfo('nn.functional.upsample_nearest', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_upsample, 'nearest'), + skips=( + # RuntimeError: false + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo( + "nn.functional.margin_ranking_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_margin_ranking_loss, + error_inputs_func=error_inputs_margin_ranking_loss, + reference_inputs_func=reference_inputs_margin_ranking_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), + OpInfo( + "nn.functional.multi_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multi_margin_loss, + reference_inputs_func=reference_inputs_multi_margin_loss, + error_inputs_func=error_inputs_multi_margin_loss, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + OpInfo( + "nn.functional.multilabel_margin_loss", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + supports_out=False, + supports_gradgrad=False, + sample_inputs_func=sample_inputs_multilabel_margin_loss, + reference_inputs_func=reference_inputs_multilabel_margin_loss, + error_inputs_func=error_inputs_multilabel_margin_loss, + ), + OpInfo('nn.functional.leaky_relu', + aliases=None, + aten_name="leaky_relu", + aten_backward_name='leaky_relu_backward', + sample_inputs_func=sample_inputs_leaky_relu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + inplace_variant=lambda x, negative_slope=0.01: + torch.nn.functional.leaky_relu(x, negative_slope, inplace=True), + supports_autograd=True, + assert_autodiffed=True, + supports_gradgrad=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::leaky_relu"]), + OpInfo( + "nn.functional.multilabel_soft_margin_loss", + supports_out=False, + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_multilabel_soft_margin_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + skips=( + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096 + # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32 + # leaked 4096 bytes CUDA memory on device 0 + DecorateInfo( + # Skip instead of expectedFailure because this fails + # locally for me but passes in CI. + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo('nn.functional.avg_pool2d', + aten_name='avg_pool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_avg_pool2d, + sample_inputs_func=sample_inputs_avgpool2d, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('nn.functional.fractional_max_pool2d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + sample_inputs_func=sample_inputs_fractional_max_pool2d, + decorators=( + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.fractional_max_pool3d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs), + # vmap does not support random operations + check_batched_forward_grad=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + test_neg_view=False, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=sample_inputs_fractional_max_pool3d, + decorators=( + # FIXME: both derivatives are implemented incorrectly + # https://github.com/pytorch/pytorch/issues/69322 + # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), + skips=( + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), + OpInfo('nn.functional.max_pool1d', + aten_name='max_pool1d', + supports_autograd=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo', + 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. + # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() + # to actually allocate memory + DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), + ), + error_inputs_func=error_inputs_max_pool1d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_pool2d', + aten_name='max_pool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + assert_jit_shape_analysis=True, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + error_inputs_func=error_inputs_max_pool2d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('max_pool2d_with_indices_backward', + op=max_pool2d_backward, + # We've defined a custom op, so there's no corresponding aten op + aten_name=None, + method_variant=None, + inplace_variant=None, + operator_variant=None, + inplace_operator_variant=None, + check_batched_gradgrad=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_max_pool, + skips=( + # We've defined a custom op here, and we don't handle the case where we receive an out kwarg + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') + )), + OpInfo('nn.functional.max_pool3d', + aten_name='max_pool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # got: Batching rule not implemented for aten::flatten.using_ints + check_batched_forward_grad=False, + # TODO: add shape checks + assert_jit_shape_analysis=False, + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + # TODO: investigate nondeterminism + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + error_inputs_func=error_inputs_max_pool3d, + sample_inputs_func=sample_inputs_max_pool), + OpInfo('nn.functional.max_unpool1d', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool1d', + variant_test_name='grad', + aten_name='max_unpool1d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool2d', + aten_name='max_unpool2d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool2d', + variant_test_name='grad', + aten_name='max_unpool2d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Vmap is not happy with non-contiguous (channels_last) inputs + check_batched_grad=False, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.max_unpool3d', + aten_name='max_unpool3d', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool, + error_inputs_func=error_inputs_max_unpool, + skips=( + # Gradients are tested in `variant_test_name=grad` below. + # We skip tests here because there is non-determinism in backward + # with gather, when there are writes into the same memory location, + # and if there are several indices pointing to the same memory, + # gradcheck is oblivious about that and cannot perturb them all at once + # (see sample_inputs_max_unpool_grad to find out more). + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'), + )), + OpInfo('nn.functional.max_unpool3d', + variant_test_name='grad', + aten_name='max_unpool3d', + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_unpool_grad), + OpInfo('nn.functional.linear', + aten_name='linear', + supports_autograd=True, + supports_gradgrad=True, + sample_inputs_func=sample_inputs_linear, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # linear calls mm under the hood which is nondeterministic on CUDA + # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_expanded_weight=True, + decorators=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('nn.functional.bilinear', + aten_name='bilinear', + supports_autograd=True, + sample_inputs_func=sample_inputs_bilinear, + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, + *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []), + decorators=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), + ), + skips=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), + ), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('nn.functional.glu', + aten_name='glu', + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + sample_inputs_func=sample_inputs_glu, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + UnaryUfuncInfo( + 'nn.functional.elu', + aten_backward_name='elu_backward', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.elu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + # Marked as a Unary function because it has some rather odd broadcasting semantics in its + # second argument + UnaryUfuncInfo( + 'nn.functional.prelu', + aten_backward_name='_prelu_kernel_backward', + ref=lambda x, weight: + np.maximum(0., x) + np.minimum(0., x) * + (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(x.ndim)])), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + # test_reference_numerics only tests the case when the weight tensor is a scalar + sample_kwargs=sample_kwargs_prelu_scalar_weight, + error_inputs_func=error_inputs_prelu, + sample_inputs_func=sample_inputs_prelu, + reference_inputs_func=reference_inputs_prelu, + decorators=[ + # RuntimeError: Cannot insert a Tensor that requires grad as a constant. + # Consider making it a parameter or input, or detaching the gradient + # https://github.com/pytorch/pytorch/issues/68752 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ], + ), + UnaryUfuncInfo( + 'nn.functional.celu', + ref=lambda x, alpha=1.0, inplace=False: + np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + ({'alpha': 0.8}, {'alpha': 0.8}), + inplace_variant=lambda x, alpha=1.0: + torch.nn.functional.celu(x, alpha, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + UnaryUfuncInfo( + 'nn.functional.rrelu', + aten_backward_name='rrelu_with_noise_backward', + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: + (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)), + sample_inputs_func=sample_inputs_rrelu, + error_inputs_func=error_inputs_rrelu, + decorators=( + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # In-place operations do not play well with forward AD + # https://github.com/pytorch/pytorch/issues/77447 + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', + 'test_inplace_forward_mode_AD'), + # The noise vector that's generated in these tests is not the same elementwise + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'), + DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + skip_correctness_check_compile_vs_eager=True, + ), + UnaryUfuncInfo( + 'nn.functional.selu', + ref=lambda x, inplace=False: + 1.0507009873554804934193349852946 * ( + np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1)) + ), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, # depends on 'elu' + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + OpInfo( + 'torch._scaled_mm_v2', + sample_inputs_func=sample_inputs_scaled_mm_v2, + dtypes=float8_types(), + dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), + supports_out=True, + supports_forward_ad=False, + supports_autograd=False, + decorators=[onlyCUDA, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], + skips=( + # Sample inputs isn't really parametrized on dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + ) + ), + OpInfo( + 'torch._scaled_mm', + sample_inputs_func=sample_inputs_scaled_mm, + dtypes=float8_types(), + dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), + supports_out=True, + supports_forward_ad=False, + supports_autograd=False, + decorators=[skipXPU, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], + skips=( + # Sample inputs isn't really parametrized on dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + ) + ), + OpInfo( + 'torch.ops.aten._safe_softmax.default', + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_safe_softmax, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_cow_input_no_materialize_backward=False, + decorators=[], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + ), + OpInfo( + 'nn.functional.scaled_dot_product_attention', + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), + sample_inputs_func=sample_inputs_scaled_dot_product_attention, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=False, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=[DecorateInfo(toleranceOverride( + {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ], + skips=( + # When attn mask is a composite tensor this fails backward by returning a none + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + # This is only failing on Linux Bionic 3.10 Cuda 11.6 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', + device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', + dtypes=(torch.float32,)), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Forward works for dtype=float64 which is the math path + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + # Not implemented for Forward AD + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + device_type='cpu'), + # Not implemented for backward derivative + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad', + device_type='cpu'), + # CPU and CUDA have inconsistencies for intermediate outputs + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cpu'), + # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', + device_type='cpu'), + # OpInfo was implemented with a lambda + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO Need to understand what this is testing and why it doesn't work + DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'), + # TODO skip this for now since we can't skip on runtime arch support + DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'), + # skip for sm < 80 + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), + # FIXME + DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), + 'TestCompositeCompliance', 'test_cow_input', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),), + ), + OpInfo( + 'torch.ops.aten._flash_attention_forward', + sample_inputs_func=sample_inputs_flash_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16) + if not SM80OrLater + else custom_types(torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], + skips=( + # Checking the scalar value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + OpInfo( + 'torch.ops.aten._efficient_attention_forward', + sample_inputs_func=sample_inputs_efficient_attention_forward, + dtypes=empty_types(), + dtypesIfCUDA=custom_types(torch.float16, torch.float32) + if not SM80OrLater + else custom_types(torch.float16, torch.float32, torch.bfloat16), + supports_out=False, + supports_autograd=True, + supports_fwgrad_bwgrad=False, + supports_forward_ad=False, + check_batched_forward_grad=False, + # TODO: Skip because it produces a CUDA illegal memory access for some reason + skip_cow_input_backward=True, + # FIXME: mask_type == 2 (LowerRight) + decorators=[ + skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), + skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2"), + skipXPU], + skips=( + # Checking the scaler value of the philox seed and offset + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), + # None Mismatch Tensor + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), + ) + ), + UnaryUfuncInfo( + 'nn.functional.silu', + aten_backward_name='silu_backward', + ref=lambda x, inplace=False: x / (1 + np.exp(-x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_autograd=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,), device_type='cpu'), + ), + autodiff_nonfusible_nodes=["aten::silu"], + ), + # TODO: combine this with the nn.functional.silu OpInfo when + # complex autodiff for silu is supported or when + # the forward bug is fixed + # Note: silu errors when given inputs that require grad + # but it doesn't support grad in their dtype + # This is why the dtypes list above passes test_dtypes, + # because it's getting lucky and failing in forward + # because test_dtypes sets requires_grad to True + # THIS IS A BUG + UnaryUfuncInfo( + 'nn.functional.silu', + variant_test_name='complex', + ref=lambda x, inplace=False: + x / (1 + np.exp(-x)), + dtypes=complex_types(), + dtypesIfCUDA=complex_types(), + supports_forward_ad=False, + supports_autograd=False, + assert_autodiffed=False, + supports_out=False, + inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-3, rtol=1e-3), + torch.bfloat16: tol(atol=1e-4, rtol=1e-4) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,)), + # FIXME: intentionally misreports dtypes + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.complex64, torch.cdouble)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.complex64,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.complex64,)))), + UnaryUfuncInfo( + 'nn.functional.hardsigmoid', + aten_backward_name='hardsigmoid_backward', + ref=reference_hardsigmoid, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=False, + supports_forward_ad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ], + skips=[ + # still want to test that first derivative works though second derivative isn't supported + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad")] + ), + UnaryUfuncInfo( + 'nn.functional.logsigmoid', + aten_name="log_sigmoid", + aten_backward_name='log_sigmoid_backward', + ref=reference_logsigmoid, + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_autograd=True, + assert_autodiffed=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + # autodiff_nonfusible_nodes=["aten::log_sigmoid"], + decorators=[ + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + DecorateInfo( + precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + ], + skips=( + # Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'), + ), + ), + UnaryUfuncInfo( + 'nn.functional.mish', + aten_backward_name='mish_backward', + ref=lambda x: x * np.tanh(reference_softplus(x)), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + inplace_variant=partial(torch.nn.functional.mish, inplace=True), + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ], + ), + UnaryUfuncInfo( + 'nn.functional.softsign', + ref=lambda x: x / (np.abs(x) + 1), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int, torch.int8)),), + ), + UnaryUfuncInfo( + 'nn.functional.tanhshrink', + ref=lambda x: x - np.tanh(x), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05), + torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + # tan(j * pi/2 * odd_number) is nan which also make tanhshrink nan. + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0) + ), + UnaryUfuncInfo( + 'nn.functional.threshold', + ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype), + dtypes=all_types_and(torch.half, torch.bfloat16), + inplace_variant=lambda x, threshold, value: + torch.nn.functional.threshold(x, threshold, value, inplace=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_gradgrad=True, + supports_out=False, + sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}, + {'threshold': float.fromhex('0x1.3ap-3'), + 'value': -9}), + # TODO(whc) should not need sample_inputs_func, but without it + # kwargs aren't being hooked up properly + sample_inputs_func=sample_inputs_threshold, + ), + OpInfo( + "nn.functional.triplet_margin_loss", + sample_inputs_func=sample_inputs_triplet_margin_loss, + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "nn.functional.triplet_margin_with_distance_loss", + sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True), + error_inputs_func=error_inputs_triplet_margin_loss, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # This test cannot handle a callable passed to `distance_function`. If we would use + # `distance_function=None`, the test would pass fine. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + DecorateInfo( + unittest.expectedFailure, + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + BinaryUfuncInfo('nextafter', + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_autograd=False, + supports_rhs_python_scalar=False), + OpInfo( + "to", + op=lambda x, *args, **kwargs: x.to(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_to, + skips=( + # RuntimeError: undefined value cpu + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + ), + # NotImplementedError: Cannot copy out of meta tensor; no data! + DecorateInfo( + unittest.skip("Skipped!"), + "TestMeta", + "test_meta_outplace", + ), + # https://github.com/pytorch/pytorch/issues/84335 + DecorateInfo( + unittest.skip("Skipped!"), + "TestProxyTensorOpInfo", + "test_make_fx_symbolic_exhaustive", + ), + DecorateInfo( + unittest.skip("Skipped!"), + "TestNormalizeOperators", + "test_normalize_operator_exhaustive", + ), + ), + ), + OpInfo('topk', + dtypes=all_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_topk), + # Multiple variants for batch_norm to test with and without cuDNN disabled + # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details + OpInfo('nn.functional.batch_norm', + aten_name='batch_norm', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + sample_inputs_func=sample_inputs_batch_norm, + skips=( + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.bfloat16, torch.float16)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}), + 'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"), + )), + # This variant tests batch_norm with cuDNN disabled only on CUDA devices + OpInfo('nn.functional.batch_norm', + variant_test_name='without_cudnn', + aten_name='batch_norm', + dtypes=empty_types(), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[1, 2], + allow_cow_input_materialize_backward=[1, 2], + decorators=[onlyCUDA, disablecuDNN], + skips=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}), + 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_batch_norm), + OpInfo( + "nn.functional.binary_cross_entropy", + aten_backward_name='binary_cross_entropy_backward', + sample_inputs_func=sample_inputs_binary_cross_entropy, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + gradcheck_fast_mode=False, + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestCudaFuserOpInfo", + ), + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.skip("Skipped!"), + "TestNNCOpInfo", + "test_nnc_correctness", + ), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}), + "TestJit", + "test_variant_consistency_jit", + ), + # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5] + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + ), + skips=( + # RuntimeError: expected int at position 0, but got: Tensor + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + ), + ), + ), + # We have to add 2 OpInfo entry for `igamma` and `igammac`.First is the + # standard entry, second is to run gradcheck tests on the second argument. + BinaryUfuncInfo('igamma', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammainc',), + dtypesIfCUDA=floating_types(), + # TODO: FIXME + supports_rhs_python_scalar=False, + supports_autograd=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implemented grad for both inputs + # BinaryUfuncInfo('igamma', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments. + # op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # skips=( + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"),"), + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + BinaryUfuncInfo('igammac', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + aliases=('torch.special.gammaincc',), + dtypesIfCUDA=floating_types(), + supports_autograd=False, + supports_rhs_python_scalar=False, + skips=( + # FIXME: incorrectly tries to pass a rhs scalar + DecorateInfo(unittest.expectedFailure, 'TestJit', + 'test_jit_alias_remapping'), + )), + # TODO: FIXME, ideally by implementing grad for both inputs + # BinaryUfuncInfo('igammac', + # variant_test_name='grad_other', + # # Since autograd formula is implemented only for other and + # # gradcheck test verifies the formula for input in SampleInput, + # # we permute the arguments + # op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs), + # inplace_variant=None, + # method_variant=None, + # supports_rhs_python_scalar=False, + # rhs_make_tensor_kwargs=dict(requires_grad=False), + # dtypes=floating_types_and(torch.bfloat16, torch.float16), + # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), + # dtypesIfCUDA=floating_types(), + # backward_dtypesIfCUDA=floating_types(), + # supports_inplace_autograd=False, + # decorators=[ + # # Derivative wrt first tensor not implemented + # DecorateInfo(unittest.expectedFailure, "TestCommon", + # "test_floating_inputs_are_differentiable"), + # ], + # skips=( + # # test does not work with passing lambda for op + # # AssertionError: False is not true : Tensors failed to compare as equal! + # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # # test fails are we permute the arguments function variant + # # but not for inplace or method. + # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float + # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), + # )), + UnaryUfuncInfo('nn.functional.softshrink', + aten_name="softshrink", + aten_backward_name='softshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + sample_inputs_func=sample_inputs_softshrink, + error_inputs_func=error_inputs_softshrink), + UnaryUfuncInfo('nn.functional.hardshrink', + aten_name="hardshrink", + aten_backward_name='hardshrink_backward', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardshrink, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardshrink"]), + UnaryUfuncInfo('nn.functional.hardtanh', + aten_name="hardtanh", + aten_backward_name='hardtanh_backward', + dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16), + backward_dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + sample_inputs_func=sample_inputs_hardtanh, + error_inputs_func=error_inputs_hardtanh, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::hardtanh"]), + OpInfo('nn.functional.gelu', + aten_name="gelu", + aten_backward_name='gelu_backward', + ref=reference_gelu if TEST_SCIPY else None, + error_inputs_func=error_inputs_gelu, + supports_autograd=True, + assert_autodiffed=True, + sample_inputs_func=sample_inputs_gelu, + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::gelu"], + skips=( + # AssertionError: Tensor-likes are not close! + # May not replicate in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + )), + UnaryUfuncInfo('nn.functional.relu6', + aten_name="relu6", + dtypes=all_types_and(torch.half, torch.bfloat16), + backward_dtypes=floating_types_and(torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=["aten::relu6"]), + OpInfo('mm', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_mm, + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Fast math on MacOS-13? + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}), + 'TestConsistency', + 'test_output_match', + active_if=lambda _: MACOS_VERSION < 14.0, + device_type='mps', + dtypes=(torch.float32,)), + )), + OpInfo('mode', + op=torch.mode, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Resized a non-empty tensor but did not warn about it + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FIXME: + # Expected 2114 but got 1123. + # Absolute difference: 991 (up to 0.001 allowed) + # Relative difference: 0.46877956480605487 (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestCommon", + "test_compare_cpu", + dtypes=(torch.float32,), + device_type="cuda", + ), + ), + sample_inputs_func=sample_inputs_mode,), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', + domain=(1, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', + domain=(2, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', + domain=(3, None), + skips=skips_mvlgamma(), + sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), + BinaryUfuncInfo('ne', + ref=np.not_equal, + aliases=('not_equal',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + always_returns_bool=True, + supports_autograd=False, + skips=( + )), + OpInfo('narrow', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), + skips=( + # Use of .item() + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('narrow_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + supports_autograd=False, + # https://github.com/pytorch/pytorch/issues/86931 + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), + skips=( + # https://github.com/pytorch/pytorch/issues/84577 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace', + device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('view_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + ref=lambda x, newshape: np.reshape(x, newshape).copy(), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + sample_inputs_func=sample_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo( + unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides" + ), + )), + UnaryUfuncInfo('neg', + aliases=('negative', ), + ref=np.negative, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + error_inputs_func=error_inputs_neg, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + OpInfo('dist', + op=torch.dist, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_dist), + OpInfo('outer', + op=torch.outer, + aliases=('ger', ), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_outer,), + OpInfo('ormqr', + op=torch.ormqr, + dtypes=floating_and_complex_types(), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=False, + supports_fwgrad_bwgrad=False, + sample_inputs_func=sample_inputs_ormqr, + error_inputs_func=error_inputs_ormqr, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], + skips=( + # Strides are not the same! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + )), + OpInfo('permute', + ref=np.transpose, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=True, + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute), + OpInfo('permute_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=False, # torch.permute is also not varargs + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), + BinaryUfuncInfo('pow', + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), + ref=np.power, + # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled + # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently + # unsupported on CPU. + backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + # FIXME Complex values error with: Greatest absolute difference: nan at index + # Ref: https://github.com/pytorch/pytorch/issues/76853 + # For `chalf`, reference computation in `numpy` is computed in `cfloat`. + # Output of `chalf` saturates to `inf` quicker than reference due to its small range + # which leads to failure of this test. + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + # FIXME: + # Mismatched elements: 1 / 500 (0.2%) + # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing', + dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + )), + BinaryUfuncInfo('float_power', + ref=np.float_power, + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + promotes_int_to_float=True, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # Integer types do not support negative exponentes + rhs_make_tensor_kwargs=dict(low=0), + # Raising negative real numbers to fractional powers is not supported + lhs_make_tensor_kwargs=dict(low=0), + decorators=( + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # FIXME + # AssertionError: Object comparison failed: torch.float64 != torch.float32 + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # -3.43399e+38 is outside the range of representable values of type 'float' + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + # Inplace always promotes to double and thus other floating dtypes are not supported + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bfloat16, torch.float16, torch.float32]), + )), + OpInfo('qr', + op=torch.qr, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_linalg_qr_geqrf, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # In-place ops + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]), + UnaryUfuncInfo('rad2deg', + ref=np.degrees, + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True), + UnaryUfuncInfo('real', + ref=np.real, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # Skip since real and imag don't have out variants. + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), + )), + OpInfo( + "roll", + ref=np.roll, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + error_inputs_func=error_inputs_roll, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_roll, + decorators=(onlyNativeDeviceTypes,), + ), + OpInfo( + "rot90", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + error_inputs_func=error_inputs_rot90, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_rot90, + ), + # To test reference numerics against multiple values of argument `decimals`, + # we make multiple OpInfo entries with each entry corresponding to different value of decimals. + UnaryUfuncInfo('round', + ref=np.round, + aliases=('special.round',), + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bfloat16,)), + ), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True, + ), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_0', + aliases=('special.round',), + dtypes=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('round', + ref=np.round, + variant_test_name='decimals_neg_3', + aliases=('special.round',), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}), + skips=( + # test_ops already tested for this overload with `decimals_0` opinfo entry + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=False, + supports_sparse_csr=False), + UnaryUfuncInfo('sin', + ref=np.sin, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + handles_large_floats=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + decorators=(precisionOverride({torch.bfloat16: 1e-2}),)), + UnaryUfuncInfo('sinc', + ref=np_sinc_with_fp16_as_fp32, + aliases=('special.sinc',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + handles_large_floats=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('sinh', + ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('sign', + ref=reference_sign, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + )), + UnaryUfuncInfo('sgn', + ref=reference_sgn, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo('split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True), + OpInfo('split', + # Cannot declare this aten_name because of + # test_variant_consistency_jit_split_list_args_cpu_float32 + decomp_aten_name='split_with_sizes', + variant_test_name='list_args', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=partial(sample_inputs_split, list_args=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + # `unsafe_split` supports only `int` for split_size argument + OpInfo('unsafe_split', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=partial(sample_inputs_split, list_args=False), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_autodiffed=True, + check_batched_forward_grad=False), + OpInfo('split_with_sizes', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo('split_with_sizes_copy', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_split_with_sizes, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # No error raised + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"), + )), + BinaryUfuncInfo('__radd__', + op=torch.Tensor.__radd__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::add'],), + BinaryUfuncInfo('__rdiv__', + op=torch.Tensor.__rdiv__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + promotes_int_to_float=True, + lhs_make_tensor_kwargs={'exclude_zero': True}, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + skips=( + # https://github.com/pytorch/pytorch/issues/76806 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],), + BinaryUfuncInfo('__rmul__', + op=torch.Tensor.__rmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + autodiff_nonfusible_nodes=['aten::mul'],), + BinaryUfuncInfo('__rand__', + op=torch.Tensor.__rand__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__ror__', + op=torch.Tensor.__ror__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + BinaryUfuncInfo('__rxor__', + op=torch.Tensor.__rxor__, + dtypes=integral_types_and(torch.bool), + supports_out=False, + supports_autograd=False, + supports_forward_ad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + )), + OpInfo('__rmatmul__', + op=torch.Tensor.__rmatmul__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, + *[torch.bfloat16] + if SM53OrLater or TEST_WITH_ROCM else []), + assert_autodiffed=True, + sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + decorators=( + # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}), + "TestDecomp", "test_comprehensive", device_type="cuda", + active_if=TEST_WITH_ROCM), + ), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # https://github.com/pytorch/pytorch/issues/67470 + DecorateInfo(unittest.skip("67470!"), + 'TestCommon', 'test_noncontiguous_samples', + device_type='cpu', dtypes=(torch.long,)), + # Fails on XLA. + # AssertionError: False is not true : Tensors failed to compare as equal + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), + # https://github.com/pytorch/pytorch/issues/71774 + DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', + device_type='cpu', dtypes=(torch.long,)), + )), + BinaryUfuncInfo('__rmod__', + op=torch.Tensor.__rmod__, + dtypes=floating_types_and(torch.bfloat16, torch.half,), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + # Support autograd after torch.remainder(Tensor, Tensor) supports + # autograd of the second argument. + # https://github.com/pytorch/pytorch/pull/58476/files#r637167630 + # supports_autograd=False, + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::remainder'],), + BinaryUfuncInfo('__rpow__', + op=torch.Tensor.__rpow__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + # Reference: https://github.com/pytorch/pytorch/issues/54774 + # "log2" "_vml_cpu" not implemented for Half + backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + # TODO: FIXME tolerance is too high + DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'), + DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::pow'],), + BinaryUfuncInfo('__rsub__', + op=torch.Tensor.__rsub__, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_one_python_scalar=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), + ), + assert_autodiffed=True, + autodiff_nonfusible_nodes=['aten::rsub'],), + BinaryUfuncInfo('rsub', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + supports_inplace_autograd=False, + assert_autodiffed=None, + sample_inputs_func=sample_inputs_add_sub), + OpInfo('select', + aten_backward_name='select_backward', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_select, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('select_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_select_scatter, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False), + OpInfo('slice', + op=torch.ops.aten.slice.Tensor, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), + sample_inputs_func=sample_inputs_slice, + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_scripting=False, + supports_inplace_autograd=False, + supports_out=False), + OpInfo('slice_scatter', + dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=True), + UnaryUfuncInfo('signbit', + ref=np.signbit, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False,), + UnaryUfuncInfo('tan', + ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + # FIXME: + # Mismatched elements: 2 / 400 (0.5%) + # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed) + DecorateInfo( + unittest.skip("Skipped!"), + "TestInductorOpInfo", + "test_comprehensive", + dtypes=(torch.float16,), + device_type="cuda", + ), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + # tan(pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)), + UnaryUfuncInfo('tanh', + ref=np.tanh, + aten_backward_name='tanh_backward', + aliases=('nn.functional.tanh',), + decorators=(precisionOverride({torch.bfloat16: 1e-2}), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda'),), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + # tan(j * pi/2 * odd_number) is nan + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + OpInfo('tensor_split', + ref=np.array_split, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Pre-existing condition; Needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + ), + sample_inputs_func=sample_inputs_tensor_split,), + OpInfo('hsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_hsplit, + error_inputs_func=error_inputs_hsplit,), + OpInfo('vsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_vsplit, + error_inputs_func=error_inputs_vsplit,), + OpInfo('dsplit', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_dsplit, + error_inputs_func=error_inputs_dsplit,), + OpInfo('triangular_solve', + op=torch.triangular_solve, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_legacy_solve, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}), + 'TestConsistency', 'test_output_match', device_type='cpu', + ), + ], + skips=( + # AssertionError: Scalars are not equal! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # Gradcheck fails + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', + dtypes=floating_and_complex_types()), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + UnaryUfuncInfo('trunc', + aliases=('fix', ), + ref=np.trunc, + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=tuple(t for t in integral_types() if t != torch.uint8)), + ), + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + assert_autodiffed=True), + UnaryUfuncInfo('exp2', + aliases=('special.exp2', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + )), + UnaryUfuncInfo('expm1', + aliases=('special.expm1', ), + ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + promotes_int_to_float=True, + assert_autodiffed=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + UnaryUfuncInfo('nan_to_num', + ref=np.nan_to_num, + dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse=True, + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + ), + # Passing numpy_kwargs via sample_kwargs, as numpy does comparison + # with BFloat16 in float, since it currently doesn't support BFloat16. + # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556 + sample_kwargs=lambda device, dtype, input: ({}, + {'posinf': torch.finfo(torch.bfloat16).max, + 'neginf': torch.finfo(torch.bfloat16).min}) + if dtype is torch.bfloat16 else ({}, {})), + UnaryUfuncInfo('reciprocal', + ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + )), + UnaryUfuncInfo('rsqrt', + ref=lambda x: np.reciprocal(np.sqrt(x)), + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.half: 5e-2}),), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + )), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + supports_sparse=True, + domain=(0, None), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + UnaryUfuncInfo('square', + ref=np.square, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + # >>> t = torch.tensor(complex(-0.01, float("inf"))) + # >>> np.square(t.numpy()) + # (-inf-infj) + # >>> t.square() + # tensor(-inf-infj) + # >>> t.cuda().square() + # tensor(inf+nanj, device='cuda:0') + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', + dtypes=[torch.bool]), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', + dtypes=[torch.bool]), + ),), + OpInfo('lerp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_lerp, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('angle', + ref=np.angle, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_complex_to_float=True, + skips=( + # Ref: https://github.com/pytorch/pytorch/issues/78413 + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),), + )), + UnaryUfuncInfo('isfinite', + ref=np.isfinite, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isinf', + ref=np.isinf, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isposinf', + ref=np.isposinf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isneginf', + ref=np.isneginf, + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + UnaryUfuncInfo('isreal', + ref=np.isreal, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + supports_out=False, + supports_autograd=False), + UnaryUfuncInfo('isnan', + ref=np.isnan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + supports_out=False, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_autograd=False), + OpInfo('einsum', + # we need this lambda because SampleInput expects tensor input as the first argument + # TODO(@heitorschueroff) update SampleInput to handle such cases + op=lambda tensors, equation: torch.einsum(equation, tensors), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + sample_inputs_func=sample_inputs_einsum, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # test does not work with passing lambda for op + # there's a test `test_einsum` in `test_jit.py` to handle this case + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('svd', + op=torch.svd, + dtypes=floating_and_complex_types(), + sample_inputs_func=sample_inputs_svd, + # Runs very slowly on slow-gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + # We're using at::allclose, which does not have a batching rule + check_batched_grad=False, + check_batched_gradgrad=False, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], + skips=( + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('svd_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + # Due to the use of randomness + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_svd_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=1e-02, rtol=1e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), + )), + OpInfo('pca_lowrank', + op=lambda *args, **kwargs: wrapper_set_seed( + lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), + *args, **kwargs + ), + dtypes=floating_and_complex_types(), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + check_batched_forward_grad=False, + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_pca_lowrank, + decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=4e-02, rtol=4e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}), + 'TestOperators', 'test_grad'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! + # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), + ], + skips=( + # test does not work with passing lambda for op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + BinaryUfuncInfo('polar', + dtypes=floating_types(), + # this function is undefined if 'abs' values are <0 + supports_forward_ad=True, + lhs_make_tensor_kwargs=dict(low=0), + supports_rhs_python_scalar=False, + skips=( + # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0 + # Numerical: + # tensor([[0.]], dtype=torch.float64) + # Analytical: + # tensor([[-0.0047]], dtype=torch.float64, grad_fn=) + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + )), + # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries. + # To test reference numerics against multiple values of argument `n`, + # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4). + # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing. + UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name='polygamma_n_0', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)), + *(UnaryUfuncInfo('polygamma', + op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), + variant_test_name=f'polygamma_n_{n_}', + ref=reference_polygamma if TEST_SCIPY else None, + dtypes=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + sample_inputs_func=sample_inputs_polygamma, + decorators=( + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1), + torch.float32: tol(atol=1e-4, rtol=1e-2)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal', + active_if=IS_WINDOWS), + ), + skips=( + # Redundant tests + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), + # Mismatch: https://github.com/pytorch/pytorch/issues/55357 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}), + # polygamma functions have multiple singularities at x having non-positive integer value + reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), + safe_val=1)) + for n_ in (1, 2, 3, 4)), + OpInfo('ravel', + ref=np.ravel, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_ravel, + ), + OpInfo('unravel_index', + ref=np.unravel_index, + dtypes=integral_types_and(), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_unravel_index, + ), + OpInfo('reshape', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo('reshape_as', + op=lambda x, other: x.reshape_as(other), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('view', + op=lambda x, shape: x.view(shape), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + sample_inputs_func=sample_inputs_view_reshape, + reference_inputs_func=reference_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('view_as', + op=lambda x, other: x.view_as(other), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), + reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), + error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides") + )), + OpInfo('atleast_1d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_atleast1d2d3d, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + ), + OpInfo('atleast_2d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('atleast_3d', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), + ), + sample_inputs_func=sample_inputs_atleast1d2d3d, + ), + OpInfo('flatten', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_flatten, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_flatten, + reference_inputs_func=reference_inputs_flatten, + ), + OpInfo('unflatten', + op=torch.unflatten, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_unflatten, + ), + OpInfo('column_stack', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_column_stack,), + OpInfo('pinverse', + op=torch.pinverse, + dtypes=floating_and_complex_types(), + check_batched_grad=False, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + sample_inputs_func=sample_inputs_linalg_invertible, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', + device_type='mps', dtypes=[torch.float32]), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', + device_type='mps', dtypes=[torch.float32]), + )), + OpInfo('gather', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_gather, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_gather, + ), + OpInfo('index_fill', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + inplace_variant=torch.Tensor.index_fill_, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), + # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! + DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'), + ), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True)), + OpInfo('index_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_select', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_select, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + OpInfo('index_add', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + inplace_variant=torch.Tensor.index_add_, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_index, + reference_inputs_func=partial(sample_inputs_index, reference=True), + error_inputs_func=error_inputs_index_add, + skips=( + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestNNCOpInfo', + 'test_nnc_correctness', + dtypes=(torch.bool,)), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), + *(OpInfo('index_reduce', + variant_test_name=reduction_type, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive'), + ), + supports_out=True, + sample_inputs_func=sample_inputs_index_reduce, + ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), + OpInfo('_unsafe_masked_index', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs__unsafe_masked_index, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('_unsafe_masked_index_put_accumulate', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + supports_out=False, + supports_inplace_autograd=False, + supports_scripting=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu' + ), + ), + sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', + dtypes=(torch.float64,), active_if=IS_WINDOWS), + ),), + OpInfo('__getitem__', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_inplace_autograd=False, + supports_scripting=False, + op=torch.Tensor.__getitem__, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),), + sample_inputs_func=sample_inputs_getitem), + OpInfo('index_put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_inplace_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + test_neg_view=False, + sample_inputs_func=sample_inputs_index_put, + skips=( + DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], + device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), + )), + OpInfo('sort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM), + )), + OpInfo('unique', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), + sample_inputs_func=sample_inputs_unique, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('unique_consecutive', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_unique_consecutive, + supports_out=False, + supports_autograd=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('put', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + check_batched_gradgrad=False, # vmap complains of the sizes + sample_inputs_func=sample_inputs_put), + OpInfo('take', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_grad=False, # vmap complains of the sizes + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_take, + error_inputs_func=error_inputs_take), + OpInfo('scatter', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter, + error_inputs_func=error_inputs_scatter_and_scatter_add, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + UnaryUfuncInfo( + 'bfloat16', + op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'bool', + op=lambda x, *args, **kwargs: x.bool(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attributis not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'byte', + op=lambda x, *args, **kwargs: x.byte(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_byte, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'char', + op=lambda x, *args, **kwargs: x.char(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + # The autograd test runner cannot handle functions that change dtype + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'double', + op=lambda x, *args, **kwargs: x.double(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'float', + op=lambda x, *args, **kwargs: x.float(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'half', + op=lambda x, *args, **kwargs: x.half(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=True, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + UnaryUfuncInfo( + 'int', + op=lambda x, *args, **kwargs: x.int(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'long', + op=lambda x, *args, **kwargs: x.long(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'short', + op=lambda x, *args, **kwargs: x.short(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + )), + UnaryUfuncInfo( + 'cdouble', + op=torch.Tensor.cdouble, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'cfloat', + op=torch.Tensor.cfloat, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # RuntimeError: attribute lookup is not defined on builtin + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + )), + UnaryUfuncInfo( + 'chalf', + op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_conversion, + skips=( + # autograd tests don't handle operators that change dtype + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), + # use of lambda doesn't work with test_normalize_operator_exhaustive + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', + device_type='cpu'), + # TypeError: 'int' object is not iterable + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view', + device_type='cpu'), + # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' + # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf' + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + OpInfo('empty_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + reference_inputs_func=reference_inputs_like_fns, + supports_autograd=False, + skips=( + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), + "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('zeros_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + )), + OpInfo('ones_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + )), + OpInfo('randn', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs), + supports_out=True, + sample_inputs_func=sample_inputs_randn, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randn generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), + # randn fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), + )), + OpInfo('randn_like', + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + error_inputs_sparse_func=error_inputs_sparse_like_fns, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('rand_like', + dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_like_fns, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('randint', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.randint, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint, + supports_autograd=False, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + # CPU randint generates different values based on the strides of out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # randint fails to warn when resizing its out tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Tests that assume input tensor has a meaningful effect on output tensor + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices', + dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM), + )), + OpInfo('randint_like', + dtypes=all_types_and(torch.half, torch.bfloat16), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.randint_like, inp, *args, **kwargs), + supports_out=False, + sample_inputs_func=sample_inputs_randint_like, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('full_like', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, + torch.uint16, torch.uint32), + supports_out=False, + sample_inputs_func=sample_inputs_full_like, + supports_autograd=False, + ), + OpInfo('new_zeros', + op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('new_ones', + op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('ones', + op=torch.ones, + supports_autograd=False, + supports_varargs=True, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('zeros', + op=torch.zeros, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_ones_zeros, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), + OpInfo('full', + op=torch.full, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_full, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # RuntimeError: UNSUPPORTED DTYPE: bool + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), + )), + OpInfo('new_empty', + op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_fns, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + supports_autograd=False), + OpInfo('new_empty_strided', + op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True), + supports_autograd=False, + skips=( + # FX failed to normalize op + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Lazy tensor failures + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestDecomp', 'test_quick'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), + DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), + 'TestNNCOpInfo', 'test_nnc_correctness'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('empty_strided', + op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_empty_strided, + skips=( + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'), + # Lazy tensor failures + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'), + # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single + # memory location. Please clone() the tensor before performing the operation. + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), + )), + OpInfo('empty', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('eye', + dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_eye, + error_inputs_func=error_inputs_eye, + supports_out=True, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # TODO: same as this? + # https://github.com/pytorch/pytorch/issues/81774 + # also see: arange, new_full + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), + )), + OpInfo('empty_permuted', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_empty_permuted, + error_inputs_func=error_inputs_empty_permuted, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), + # Empty tensor data is garbage so it's hard to make comparisons with it. + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance', + 'test_operator'), + # requires_grad doesn't exist in the jit schema + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestLazyOpInfo'), + DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), + 'TestCommon', 'test_complex_half_reference_testing'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + )), + OpInfo('scalar_tensor', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_scalar_tensor, + supports_autograd=False, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('new_full', + op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=False, + sample_inputs_func=sample_inputs_new_full, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + ), + supports_autograd=False), + OpInfo('multinomial', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.multinomial, inp, *args, **kwargs), + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_multinomial, + error_inputs_func=error_inputs_multinomial, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Strides are not the same! + # This may not be reproducible in CI + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_autograd=False), + OpInfo('normal', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.normal, inp, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_first, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # RuntimeError: Difference from {dtype} is larger with decomposition + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # The inplace variant (Tensor.normal_) is different from torch.normal + # inplace variant Tensor.normal_ is decomposed using randn_like() + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))), + OpInfo('normal', + # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here + variant_test_name='number_mean', + op=lambda std, mean, *args, **kwargs: + wrapper_set_seed(torch.normal, mean, std, *args, **kwargs), + # The inplace variant (Tensor.normal_) is different from torch.normal + inplace_variant=None, + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + sample_inputs_func=sample_inputs_normal_tensor_second, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'), + # AssertionError + DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'), + # AssertionError in CUDA variant + DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))), + OpInfo('bernoulli', + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs), + # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli + inplace_variant=None, + method_variant=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.half), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_bernoulli, + error_inputs_func=error_inputs_bernoulli, + skips=( + # vmap: We do not yet support calling random operations inside of vmap + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Expected RuntimeError when doing an unsafe cast from a result of + # dtype torch.float32 into an out= with dtype torch.lon + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), + OpInfo('scatter_add', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + inplace_variant=torch.Tensor.scatter_add_, + sample_inputs_func=sample_inputs_scatter_add, + error_inputs_func=error_inputs_scatter_and_scatter_add, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('stack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_stack, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # https://github.com/pytorch/pytorch/issues/77046 + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + )), + OpInfo('_chunk_cat', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_chunk_cat, + error_inputs_func=error_inputs_chunk_cat, + supports_autograd=False, + supports_out=True, + ), + OpInfo('hstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + BinaryUfuncInfo('hypot', + dtypes=floating_types_and(torch.bfloat16, torch.half), + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_rhs_python_scalar=False), + OpInfo('histogram', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogram is only implemented on CPU + sample_inputs_func=sample_inputs_histogram, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0): + # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Not Implemented on XLA. + DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'), + )), + OpInfo('histogramdd', + dtypes=floating_types(), + dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU + sample_inputs_func=sample_inputs_histogramdd, + error_inputs_func=error_inputs_histogramdd, + supports_autograd=False, + skips=( + # Not implemented on CUDA + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('histc', + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), + sample_inputs_func=sample_inputs_histc, + supports_out=True, + supports_autograd=False, + skips=( + # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor + # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast + # from a result of dtype torch.float32 into an out= with dtype torch.long" + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), + )), + OpInfo('bincount', + dtypes=integral_types_and(), + sample_inputs_func=sample_inputs_bincount, + supports_out=False, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('bucketize', + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_bucketize, + reference_inputs_func=reference_inputs_bucketize, + error_inputs_func=error_inputs_bucketize, + supports_autograd=False, + skips=( + # JIT tests don't work with Tensor keyword arguments + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('searchsorted', + dtypes=all_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_searchsorted, + supports_autograd=False, + ref=reference_searchsorted, + skips=( + # JIT tests don't work with Tensor keyword arguments + # https://github.com/pytorch/pytorch/issues/58507 + DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), + )), + OpInfo('cat', + ref=_cat_np, + aliases=('concat', 'concatenate'), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + sample_inputs_func=sample_inputs_cat_concat, + reference_inputs_func=reference_inputs_cat, + error_inputs_func=error_inputs_cat, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + assert_autodiffed=True, + skips=( + # https://github.com/pytorch/pytorch/issues/89353 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), + # RuntimeError: Arguments for call not valid. + # Expected a value of type 'List[Tensor]' for argument + # 'tensors' but instead found type 'Tensor (inferred)'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), + # see https://github.com/pytorch/pytorch/issues/71286 + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), + # see https://github.com/pytorch/pytorch/issues/99806 + # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), + )), + OpInfo('unbind', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=False, + ), + OpInfo('unbind_copy', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=True, + check_batched_grad=False, + ), + OpInfo('vstack', + aliases=('row_stack',), + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: _fn() Expected a value of type + # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'. + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)), + OpInfo('dstack', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_hstack_dstack_vstack, + error_inputs_func=error_inputs_hstack_dstack_vstack, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo('unfold', + op=lambda x, *args: x.unfold(*args), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Skip operator schema test because this is a functional and not an operator + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ), + sample_inputs_func=sample_inputs_unfold), + OpInfo('unfold_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_gradgrad=False, + # See https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_unfold), + OpInfo('msort', + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_msort), + OpInfo('movedim', + aliases=('moveaxis',), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_movedim_moveaxis, + reference_inputs_func=reference_movedim_moveaxis, + error_inputs_func=error_movedim_moveaxis), + OpInfo('renorm', + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_renorm, + error_inputs_func=error_inputs_renorm, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: Difference from float64 is larger with decomposition + # linalg_vector_norm.default than original on output 0. + # Original max diff: 2.560596747969157e-07, + # Decomp max diff: 1.8187482915266173e-06 + DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive', + device_type='cpu', dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}), + "TestConsistency", "test_output_match", device_type="mps"), + )), + ShapeFuncInfo('repeat', + op=lambda x, dims: x.repeat(dims), + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile, + skips=( + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + )), + OpInfo('squeeze', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze), + OpInfo('squeeze', + ref=_squeeze_ref, + variant_test_name="multiple", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze_multiple), + OpInfo('squeeze_copy', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze, + skips=( + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + UnaryUfuncInfo( + 'fill', + ref=_fill_np, + method_variant=None, + sample_kwargs=_fill_sample_kwargs, + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + skips=( + # JIT has issue when op is passed as lambda + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'), + DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'), + )), + OpInfo('resize_', + op=lambda x, shape: x.clone().resize_(shape), + method_variant=None, + inplace_variant=torch.Tensor.resize_, + # the test fails because resize_ doesn't work with imag views as expected by the test + # https://github.com/pytorch/pytorch/issues/65945 + test_neg_view=False, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('resize_as_', + op=lambda x, other: torch.resize_as_(x.clone(), other), + method_variant=None, + inplace_variant=torch.Tensor.resize_as_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + skips=( + # Cannot resize variables that require grad + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + sample_inputs_func=sample_inputs_resize_ops), + OpInfo('take_along_dim', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_take_along_dim, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + ShapeFuncInfo('tile', + ref=np.tile, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_repeat_tile), + OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid' + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('trapezoid', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' + ), + ], + sample_inputs_func=sample_trapezoid), + OpInfo('cumulative_trapezoid', + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + supports_out=False, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', + ), + ), + sample_inputs_func=sample_cumulative_trapezoid,), + OpInfo('unsqueeze', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze), + OpInfo('unsqueeze_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + assert_jit_shape_analysis=True, + assert_autodiffed=True, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + sample_inputs_func=sample_unsqueeze, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), + BinaryUfuncInfo('xlogy', + aliases=('special.xlogy',), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + promotes_int_to_float=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_one_python_scalar=True, + # We don't test 0 as the gradient will be NaN and it'll break + rhs_make_tensor_kwargs=dict(low=0.01)), + OpInfo('zero_', + op=lambda x: torch.zero_(x.clone()), + method_variant=None, + inplace_variant=torch.Tensor.zero_, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + ), + sample_inputs_func=sample_inputs_zero_), + OpInfo('logsumexp', + aliases=('special.logsumexp',), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + gradcheck_fast_mode=False, + sample_inputs_func=sample_inputs_logsumexp, + reference_inputs_func=reference_inputs_logsumexp), + OpInfo('trace', + dtypes=all_types_and_complex(), + dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), + error_inputs_func=error_inputs_trace, + supports_inplace_autograd=False, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_trace), + OpInfo('transpose', + ref=_numpy_ref_transpose, + aliases=('swapdims', 'swapaxes'), + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims), + OpInfo('transpose_copy', + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + )), + OpInfo('T', + op=lambda x: x.T, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T, + error_inputs_func=error_inputs_T), + OpInfo('H', + op=lambda x: x.H, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_T), + OpInfo('mT', + op=lambda x: x.mT, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('mH', + op=lambda x: x.mH, + aliases=('adjoint',), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), + sample_inputs_func=sample_inputs_adjoint), + OpInfo('tril', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, + sample_inputs_func=sample_inputs_tril_triu, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('triu_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('tril_indices', + dtypes=_dispatch_dtypes((torch.int32, torch.int64)), + sample_inputs_func=sample_inputs_trilu_indices, + ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype), + supports_out=False, + supports_autograd=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + OpInfo('kron', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_inplace_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_kron, + decorators=( + # RuntimeError: view size is not compatible with input tensor's size and stride + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + )), + OpInfo('inner', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_inner, + ), + OpInfo('tensordot', + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_tensordot, + skips=( + # Skip operator schema test because this is a functional and not an operator. + # Reference: https://github.com/pytorch/pytorch/issues/54574 + DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + ) + ), + OpInfo('to_sparse', + op=lambda x, *args: x.to_sparse(*args), + sample_inputs_func=sample_inputs_to_sparse, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + backward_dtypes=floating_types(), + backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_sparse_csr=True, + supports_sparse_csc=True, + check_batched_grad=False, + check_batched_gradgrad=False, + skips=( + # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend + DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'), + # TODO: FIXME: complex inputs requiring grad error in forward + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # Allowed exception: sparse tensors don't have strides + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'), + # TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1. + DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"), + 'TestSparseCSR', 'test_sparse_csr_consistency'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ) + ), + OpInfo('logcumsumexp', + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'), + # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble' + # Falling back to non-numerically stabilized exp, causing nan in the results. + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=7e-5, rtol=6e-3), + }), + "TestInductorOpInfo", + "test_comprehensive", + device_type="cuda" + ), + ), + sample_inputs_func=sample_inputs_logcumsumexp, + error_inputs_func=error_inputs_logcumsumexp), + UnaryUfuncInfo('sigmoid', + aliases=('special.expit', 'nn.functional.sigmoid'), + aten_backward_name='sigmoid_backward', + ref=reference_sigmoid if TEST_SCIPY else None, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + assert_autodiffed=True, + # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero + reference_numerics_filter=NumericsFilter( + condition=lambda x: (close_to_int(x / (math.pi * 1j)) + if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), + safe_val=0)), + UnaryUfuncInfo('digamma', + ref=scipy.special.digamma if TEST_SCIPY else None, + aliases=('special.psi', 'special.digamma',), + decorators=(precisionOverride({torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erf', + ref=scipy.special.erf if TEST_SCIPY else None, + aliases=('special.erf', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), + 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + + ), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfc', + ref=scipy.special.erfc if TEST_SCIPY else None, + aliases=('special.erfc', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True), + UnaryUfuncInfo('erfinv', + ref=scipy.special.erfinv if TEST_SCIPY else None, + aliases=('special.erfinv', ), + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + domain=(-1, 1), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo(unittest.expectedFailure, 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), + )), + OpInfo("nn.functional.smooth_l1_loss", + ref=reference_smooth_l1_loss, + sample_inputs_func=sample_inputs_smooth_l1_loss, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + backward_dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)), + OpInfo( + "nn.functional.l1_loss", + ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)), + sample_inputs_func=sample_inputs_l1_loss, + error_inputs_func=error_inputs_l1_loss, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED + # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + ), + ), + UnaryUfuncInfo('lgamma', + ref=reference_lgamma if TEST_SCIPY else None, + aliases=('special.gammaln', ), + decorators=(precisionOverride({torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + # lgamma have multiple singularities at x <= 0 + reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), + OpInfo( + 'logdet', + dtypes=floating_and_complex_types(), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), + # `log_softmax` supports different dtypes based on whether `dtype` argument, + # is passed or not. Hence two OpInfo entries, one with dtype and other without. + OpInfo( + 'log_softmax', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + aten_backward_name='_log_softmax_backward_data', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_softmax_variant, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + OpInfo( + 'log_softmax', + variant_test_name='with_dtype', + aliases=('special.log_softmax', 'nn.functional.log_softmax'), + supports_out=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True), + UnaryUfuncInfo('logit', + aten_backward_name='logit_backward', + ref=scipy.special.logit if TEST_SCIPY else None, + domain=(0, 1), + aliases=('special.logit', ), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_float=True, + decorators=(precisionOverride({torch.bfloat16: 5e-1, + torch.float16: 5e-1}),), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_logit), + OpInfo('where', + # Currently only the `input` is tested in gradcheck. + # If we pass `condition` first, none of the input which supports + # autograd will be tested. Hence the following lambda. + op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs), + ref=lambda self, condition, other: np.where(condition, self, other), + sample_inputs_func=sample_inputs_where, + reference_inputs_func=reference_inputs_where, + error_inputs_func=error_inputs_where, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=( + DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + ), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)), + OpInfo('nonzero', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero, + supports_autograd=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # nonzero(): argument 'out' must be Tensor, not tuple + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # https://github.com/pytorch/pytorch/issues/67458 + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # nonzero is not raising a warning when the out is resized + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # Can't find schemas for this operator for some reason + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + OpInfo('nonzero_static', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_nonzero_static, + supports_out=False, + supports_autograd=False, + decorators=[onlyCPU], + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + )), + # Following tests are for jiterator's python interface + # Jiterator can be used to author elementwise CUDA kernel + # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op + # See create_jit_fn in jiterator.py for more information + UnaryUfuncInfo( + 'jiterator_unary', + op=torch.cuda.jiterator._create_jit_fn("template T unary(T x) { return x * x + x; }"), + ref=lambda x: x * x + x, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[ + onlyCUDA, + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_hard'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_normal'), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), + ], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + dtypes=[torch.bool]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=[torch.bool]), + # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], active_if=TEST_WITH_ROCM), + # Newer numpy generates -inf+infj instead of nan+infj for complex64 for some of the results + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.complex64], device_type='cuda'), + # Expected failure: torch.jiterator_unary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1), + ref=lambda input, other, *, alpha=1: ( + np.add(input, other) + if alpha == 1 + else np.add(input, np.multiply(alpha, other)) + ), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_binary is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_4inputs_with_extra_args', + op=torch.cuda.jiterator._create_jit_fn( + "template T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }", + alpha=1, beta=1), + ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + BinaryUfuncInfo( + 'jiterator_binary_return_by_ref', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_return_by_ref(T i0, T i1, T& out0) { + out0 = i0 + i1; + } + """, + num_outputs=1), + ref=operator.add, + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + supports_rhs_python_scalar=False, + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + OpInfo( + 'jiterator_2inputs_2outputs', + op=torch.cuda.jiterator._create_multi_output_jit_fn( + """ + template + void binary_2outputs(T i0, T i1, T& out0, T& out1) { + out0 = i0 + i1; + out1 = i0 - i1; + } + """, + num_outputs=2), + ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), + sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2), + supports_out=False, + supports_autograd=False, # jiterator ops doesn't have backward defined + decorators=[onlyCUDA], + skips=( + # Jiterator ops doesn't support neg or conj view + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Jiterator ops doesn't support CompositeCompliantTensor + # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped + DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), + # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Skip Nvfuser + DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), + ) + ), + # `torch.norm` has multiple code paths depending on the value of `p`. + # These paths have different dtype support. Also JIT supports, + # most variants but not all of them. So we split the OpInfo entries, + # for `norm` based on the code-paths and JIT support. + OpInfo( + "norm", + sample_inputs_func=sample_inputs_norm, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # TODO Benchmark again with the new implementation + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='nuc', + sample_inputs_func=sample_inputs_norm_nuc, + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], + check_batched_gradgrad=False, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_and_complex_types(), + dtypesIfCUDA=floating_and_complex_types(), + skips=( + # Dispatches in Python to matrix_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo('norm', + variant_test_name='fro', + sample_inputs_func=sample_inputs_norm_fro, + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients + # got: Could not allocate memory to change Tensor SizesAndStrides! + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + skips=( + # MPS has some mild accuracy issues for float16. We divide the tolerances by 10 + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}), + 'TestConsistency', + 'test_output_match', + + ), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo( + unittest.skip("Skipped!"), + 'TestSchemaCheckModeOpInfo', + 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), + OpInfo( + "norm", + variant_test_name="inf", + sample_inputs_func=sample_inputs_norm_inf, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + # fast gradcheck produces NaNs + gradcheck_fast_mode=False, + skips=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', + ), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)) + ), + ), + OpInfo('t', + sample_inputs_func=sample_inputs_t, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo('t_copy', + sample_inputs_func=sample_inputs_t, + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused + autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + assert_autodiffed=True, + error_inputs_func=error_inputs_t), + OpInfo( + "nn.functional.dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Probably because we have used lambda for the op here + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # inplace variant dispatches to dropout kernel, while on CUDA + # the op dispatches to _fused_dropout (with a few more conditions) + # hence, different values and this skip here + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "native_dropout_backward", + op=torch.ops.aten.native_dropout_backward.default, + aten_name="native_dropout_backward", + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_dropout_backward, + skips=( + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # Lazy tensor failures + DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + # These tests fail only when built with ASAN + DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN), + DecorateInfo( + unittest.skip("Fails with ASAN"), + 'TestLazyOpInfo', + 'test_correctness_with_reusing_ir', + active_if=TEST_WITH_ASAN + ), + ), + ), + OpInfo( + "nn.functional.dropout2d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (3, 4) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.dropout3d", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + check_batched_forward_grad=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + check_batched_forward_grad=False, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: Tensor-likes are not close! + # Fails in cuda11.7 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='xpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), + # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype + # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="with_train", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: + # vmap: We do not yet support calling random operations inside of vmap. + # Please perform random operations outside of vmap as a workaround + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + # As per the docs, valid input dims are (4, 5) + sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.feature_alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), + variant_test_name="without_train", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=partial(sample_inputs_dropout, train=False), + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.one_hot", + ref=reference_one_hot, + supports_out=False, + dtypes=_dispatch_dtypes((torch.int64,)), + sample_inputs_func=sample_inputs_one_hot, + ), + OpInfo( + "nn.functional.embedding", + aten_backward_name="embedding_dense_backward", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_embedding, + allow_cow_input_materialize_forward=[0], + error_inputs_func=error_inputs_embedding, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Fails on CI https://github.com/pytorch/pytorch/issues/85377 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), + # Reference: https://github.com/pytorch/pytorch/issues/67084 + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), + # Not a problem: embedding does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + # Fails due to non-determinism (see issue #74679) + # TODO: Investigate why more granular skips in the test don't work in CI + DecorateInfo(unittest.skip('Skipped!'), + 'TestExpandedWeightFunctional', + 'test_expanded_weight_forward'), + ), + supports_expanded_weight=True, + supports_out=False, + ), + OpInfo( + "nn.functional.embedding_bag", + # We use lambda to reshuffle the positional arguments. + # This is because currently only the `input` field of SampleInput + # is tested in gradient tests. + op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + # backward is not supported for mode `max` and dtype `bfloat16` + backward_dtypesIfCUDA=floating_types_and(torch.float16), + sample_inputs_func=sample_inputs_embedding_bag, + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # Not a problem: embedding_bag does weird stuff to its input (it renormalizes) + DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), + ), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + supports_out=False, + supports_gradgrad=False, + allow_cow_input_materialize_forward=[0], + ), + OpInfo( + "nn.functional.multi_head_attention_forward", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_multi_head_attention_forward, + skips=( + # Tensor-likes are not close + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'), + + # TODO skip this for now since we can't skip on runtime arch support (taken from scaled_dot_product_attention) + DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'), + # randomness + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # lambda impl + # AssertionError: JIT Test does not execute any logic + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), + # tests running very slowly break slow tests, so we skip them instead of using `slowTest`. + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_comprehensive', + dtypes=(torch.bfloat16, torch.float16), + ), + DecorateInfo( + unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), + 'TestDecomp', + 'test_quick', + dtypes=(torch.bfloat16, torch.float16))), + supports_out=False, + supports_gradgrad=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + ), + UnaryUfuncInfo( + "nn.functional.softplus", + aten_backward_name='softplus_backward', + ref=reference_softplus, + sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.bfloat16, torch.float16), + decorators=( + DecorateInfo( + toleranceOverride + ({ + torch.half: tol(atol=1e-2, rtol=1e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1e-2), + }), + 'TestUnaryUfuncs'), + ), + ), + OpInfo( + "nn.functional.mse_loss", + aten_backward_name='mse_loss_backward', + ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2), + sample_inputs_func=sample_inputs_loss, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + skips=( + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, + # please report a bug to PyTorch. + DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ), + ), + OpInfo( + "nn.functional.grid_sample", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sample, + reference_inputs_func=reference_inputs_grid_sample, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15), + # TODO: delete this OpInfo once we add meta support for grid_sampler_3d + OpInfo( + "grid_sampler_2d", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_2d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15, + skips=( + DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), + active_if=IS_WINDOWS), + ),), + # TODO: Remove grid_sampler_3d tests once `nn.functional.grid_sample` has + # MPS support for all cases. + OpInfo( + "grid_sampler_3d", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_3d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15, + skips=( + # NOTE: Only run on MPS + DecorateInfo(unittest.skip('Skipped!'), device_type='cpu'), + DecorateInfo(unittest.skip('Skipped!'), device_type='cuda'), + DecorateInfo(unittest.skip('Skipped!'), device_type='xpu'), + DecorateInfo(unittest.skip('Skipped!'), device_type='meta'), + ),), + OpInfo( + "argwhere", + ref=np.argwhere, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_autograd=False, + sample_inputs_func=sample_inputs_argwhere, + skips=( + # Compiler issue on ROCm. Might need to skip until ROCm5.5 + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + ReductionOpInfo( + 'all', + identity=True, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.all), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'any', + identity=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.any), + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'amax', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amax), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'amin', + nan_policy='propagate', + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=reference_reduction_numpy(np.amin), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + error_inputs_func=error_inputs_aminmax_amax_amin, + ), + ReductionOpInfo( + 'argmax', + supports_multiple_dims=False, + supports_autograd=False, + assert_jit_shape_analysis=True, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmax, supports_keepdims=False), + ), + ReductionOpInfo( + 'argmin', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.argmin, supports_keepdims=False), + ), + ReductionOpInfo( + 'count_nonzero', + identity=0, + supports_out=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_reduction_count_nonzero, + ref=reference_reduction_numpy(np.count_nonzero), + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionOpInfo( + 'mean', + nan_policy='propagate', + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # FIXME: mean needs 'dim' parameter when using the 'out' overload. + # Adding it with 'generate_args_kwargs' does not work, since these also get passed + # onto the reference implementations. + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + ref=reference_reduction_numpy(np.mean), + error_inputs_func=error_inputs_mean, + skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), + # FIXME: mean does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + ), + ), + ReductionOpInfo( + 'nanmean', + nan_policy='omit', + assert_autodiffed=True, + promotes_int_to_float=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nanmean), + skips=( + # AssertionError: False is not true : + # Failure in testing nodes' autodifferentiation. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + device_type='cuda', dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', + device_type='cuda', dtypes=[torch.complex64]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + ReductionOpInfo( + 'std', + nan_policy='propagate', + supports_out=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.std), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + ReductionOpInfo( + 'std', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'var', + nan_policy='propagate', + supports_out=True, + assert_autodiffed=True, + promotes_int_to_float=True, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.var), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), + # NumPy is giving NaN for this + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), + ), + ), + ReductionOpInfo( + 'var', + variant_test_name='unbiased', + nan_policy='propagate', + supports_out=False, + complex_to_real=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_autodiffed=True, + promotes_int_to_float=True, + check_batched_forward_grad=False, + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var_unbiased, + skips=( + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'prod', + identity=1, + nan_policy='propagate', + supports_multiple_dims=False, + # https://github.com/pytorch/pytorch/issues/80411 + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_prod, + ref=prod_numpy, + skips=( + # FIXME: prod does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: prod reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: prod does not support passing None to dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.uint8, torch.float16, torch.complex64]), + # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float16]), + ), + ), + ReductionOpInfo( + 'sum', + identity=0, + nan_policy='propagate', + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + ref=reference_reduction_numpy(np.sum), + error_inputs_sparse_func=error_inputs_sparse_reduction_sum, + sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo), + sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr), + sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc), + sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr), + sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc), + skips=( + # FIXME: sum does not support passing keepdim without passing dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: sum reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + ReductionOpInfo( + 'nansum', + identity=0, + nan_policy='omit', + supports_out=True, + promotes_int_to_int64=True, + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), + ref=reference_reduction_numpy(np.nansum), + skips=( + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # FIXME: nansum reduces all dimensions when dim=[] + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: flaky test so skipped instead of xfailed + # possibly bad low precision reference in numpy + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}), + "TestConsistency", "test_output_match", device_type="mps"), + ), + ), + ReductionOpInfo( + 'hash_tensor', + result_dtype=torch.uint64, + supports_autograd=False, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + ref=reference_hash_tensor, + skips=( + # hash_tensor reduces all dimensions when dim=[] (as do sum, prod etc.) + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # aten::hash_tensor hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + # NYI + DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), + # Sharding strategy NYI + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + ) + ), + OpInfo( + "nn.functional.ctc_loss", + dtypes=floating_types(), + supports_out=False, + sample_inputs_func=sample_inputs_ctc_loss, + # gradcheck_wrapper, see https://github.com/pytorch/pytorch/issues/52241 + gradcheck_wrapper=gradcheck_wrapper_ctc_loss, + skips=( + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.expectedFailure, + "TestBwdGradients", + "test_fn_gradgrad", + dtypes=(torch.float64,), + ), + # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + # Ref: https://github.com/pytorch/pytorch/issues/85231 + DecorateInfo(unittest.skip("Fails with ASAN"), + 'TestProxyTensorOpInfo', + 'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN), + ), + ), + OpInfo( + "nn.functional.cosine_embedding_loss", + dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}), + 'TestInductorOpInfo', 'test_comprehensive', device_type="cuda", + ), + ], + sample_inputs_func=sample_inputs_cosine_embedding_loss, + ), + OpInfo( + "nn.functional.nll_loss", + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_nll_loss, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + assert_jit_shape_analysis=True, + skips=( + # RuntimeError: + # undefined value tensor: + # File "", line 3 + # def the_method(i0, i1): + # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32)) + # ~~~~~~ <--- HERE + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782 + DecorateInfo( + unittest.skip("Skipped!"), + "TestCompositeCompliance", + "test_cow_input", + device_type='cuda', + ), + DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"), + dtypes=(torch.half,), device_type="mps"), + + ), + ), + OpInfo( + "nn.functional.gaussian_nll_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + # Runs very slowly on slow gradcheck - alternatively reduce input sizes + gradcheck_fast_mode=True, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_gaussian_nll_loss, + error_inputs_func=error_inputs_gaussian_nll_loss, + skips=( + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), + # Pre-existing condition (calls .item); needs to be fixed + DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + ), + ), + OpInfo( + "nn.functional.hinge_embedding_loss", + dtypes=floating_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_hinge_embedding_loss, + error_inputs_func=error_inputs_hinge_embedding_loss, + reference_inputs_func=reference_inputs_hinge_embedding_loss, + ), + OpInfo( + "nn.functional.huber_loss", + aten_backward_name='huber_loss_backward', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_huber_loss, + error_inputs_func=error_inputs_huber_loss, + skips=( + # JIT does not support variadic tensors. + # RuntimeError: input->type()->kind() == TypeKind::OptionalType + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, + # please report a bug to PyTorch. + DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), + ) + ), + OpInfo( + "nn.functional.pdist", + ref=reference_pdist, + sample_inputs_func=sample_inputs_pdist, + dtypes=floating_types(), + supports_out=False, + supports_gradgrad=False, + skips=( + DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), + ) + ), + OpInfo( + "nn.functional.poisson_nll_loss", + dtypes=all_types_and(torch.half, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_poisson_nll_loss, + error_inputs_func=error_inputs_poisson_nll_loss, + ), + OpInfo( + "argsort", + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_sort, + supports_out=False, + supports_autograd=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_non_standard_bool_values", + dtypes=[torch.bool], + device_type='cuda', + active_if=not TEST_WITH_ROCM + ), + ), + ), + OpInfo( + "repeat_interleave", + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_repeat_interleave, + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pairwise_distance", + ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: ( + np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p) + ), + sample_inputs_func=sample_inputs_pairwise_distance, + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_shuffle", + sample_inputs_func=sample_inputs_pixel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.pixel_unshuffle", + sample_inputs_func=sample_inputs_pixel_unshuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=( + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32, torch.complex64), + ), + ), + ), + OpInfo( + "nn.functional.channel_shuffle", + sample_inputs_func=sample_inputs_channel_shuffle, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + allow_cow_input_materialize_forward=[0], + allow_cow_input_materialize_backward=[0, 'output grad 0'], + skips=( + # Skip due to NotImplementedError for MPS device. + DecorateInfo(unittest.expectedFailure, 'TestConsistency'), + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + ), + ), + OpInfo( + "nn.functional.kl_div", + sample_inputs_func=sample_inputs_kl_div, + dtypes=floating_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + ), + OpInfo( + "diagflat", + ref=lambda input, offset=0: np.diagflat(input, k=offset), + sample_inputs_func=sample_inputs_diagflat, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # See https://github.com/pytorch/pytorch/pull/78358 + check_batched_forward_grad=False, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='sum', + inplace_variant=torch.Tensor.scatter_reduce_, + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Compiler issue on ROCm. Regression started in ROCm 6.4. + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', + dtypes=[torch.bool], active_if=TEST_WITH_ROCM), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='prod', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + sample_inputs_func=sample_inputs_scatter_reduce, + skips=( + # Not implemented + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + ), + ), + OpInfo( + 'scatter_reduce', + variant_test_name='mean', + # complex not added to dtypes as complex gradients are not properly handled + # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet + dtypes=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amin', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + 'scatter_reduce', + variant_test_name='amax', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), + dtypesIfHpu=custom_types(torch.float32, torch.bfloat16), + supports_forward_ad=True, + check_batched_forward_grad=False, + supports_fwgrad_bwgrad=True, + sample_inputs_func=sample_inputs_scatter_reduce, + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='lengths', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=sample_inputs_segment_reduce, + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), + OpInfo( + '_segment_reduce', + aten_name='segment_reduce', + variant_test_name='offsets', + dtypes=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented + supports_gradgrad=False, + sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'), + skips=( + # FIXME: CUDA driver API confirmed a leak in + # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 + DecorateInfo( + unittest.skip("Skipped!"), + "TestJit", + "test_variant_consistency_jit", + device_type="cuda", + ), + ), + ), +] +op_db += opinfo.definitions.op_db + + +# Separate registry for experimental Python Reference OpInfos. +python_ref_db = [ + # + # Elementwise Unary OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.abs", + torch_opinfo_name="abs", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49224 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.int8], active_if=TEST_WITH_ASAN), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acos", + torch_opinfo_name="acos", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.acosh", + torch_opinfo_name="acosh", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asin", + torch_opinfo_name="asin", + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), + 'TestUnaryUfuncs', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu' + ), + precisionOverride({torch.bfloat16: 1e-2}), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.asinh", + torch_opinfo_name="asinh", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + PythonRefInfo( + "_refs.lerp", + torch_opinfo_name="lerp", + ), + PythonRefInfo( + "_refs.ones", + torch_opinfo_name="ones", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.zeros", + torch_opinfo_name="zeros", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.cauchy", + torch_opinfo_name="cauchy", + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.exponential", + torch_opinfo_name="exponential", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip("Expected: exponential is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.geometric", + torch_opinfo_name="geometric", + supports_out=True, + decorators=( + # dtypes that do not support check_uniform_bounds of rand_like + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), + + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: geometric is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.log_normal", + torch_opinfo_name="log_normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal", + torch_opinfo_name="normal", + torch_opinfo_variant_name="number_mean", + supports_out=True, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.normal_", + op=torch.Tensor.normal_, + torch_opinfo_name="normal", + torch_opinfo_variant_name="in_place", + supports_out=False, + decorators=( + # TODO: RuntimeError: no _refs support for torch.rand_like + DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), + 'TestCommon', + 'test_python_ref'), + + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ) + ), + PythonRefInfo( + "_refs.arange", + torch_opinfo_name="arange", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.linspace", + torch_opinfo_name="linspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # cpu implementation is wrong on some integral types + # https://github.com/pytorch/pytorch/issues/81996 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), + + # cuda implementation is off-by-one on some inputs due to precision issues + # https://github.com/pytorch/pytorch/issues/82230 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO torch.ops.aten.copy is not in _refs + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16), + device_type="cpu"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + ), + ), + PythonRefInfo( + "_refs.logspace", + torch_opinfo_name="logspace", + torch_opinfo_variant_name="tensor_overload", + skips=( + # TypeError: 'int' object is not subscriptable + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + + # Off-by-one issue when casting floats to ints + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.int16, torch.int32, torch.int64), + device_type="cuda"), + # TODO copy doesn't have prim refs + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, torch.complex64, + torch.complex128, torch.bfloat16, torch.int8, torch.uint8 + ), + device_type="cuda" + ), + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=( + torch.float32, torch.float64, torch.float16, + torch.complex64, torch.complex128, torch.bfloat16, + torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8 + ), + device_type="cpu"), + ), + ), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="variadic_tensors", + ), + PythonRefInfo( + "_refs.take_along_dim", + torch_opinfo_name="take_along_dim", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestCommon', + 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.to", + torch_opinfo_name="to", + ), + PythonRefInfo( + "_refs.triu", + torch_opinfo_name="triu", + ), + PythonRefInfo( + "_refs.tril", + torch_opinfo_name="tril", + ), + PythonRefInfo( + "_refs.triu_indices", + torch_opinfo_name="triu_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.tril_indices", + torch_opinfo_name="tril_indices", + # the implementation uses torch.stack that violates view consistency + validate_view_consistency=False, + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), + )), + PythonRefInfo( + "_refs.meshgrid", + torch_opinfo_name="meshgrid", + torch_opinfo_variant_name="list_of_tensors", + ), + PythonRefInfo( + "_refs.movedim", + aliases=('moveaxis',), + torch_opinfo_name="movedim", + ), + PythonRefInfo( + "_refs.bucketize", + torch_opinfo_name="bucketize", + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor with + # aten._local_scalar_dense.default - erroring out! [...] + # triggered by mid_val = boundaries[mid] + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"), + ) + ), + PythonRefInfo( + "_refs.equal", + torch_opinfo_name="equal", + skips=( + # RuntimeError: Cannot cast FakeTensor to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atan", + torch_opinfo_name="atan", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.atanh", + torch_opinfo_name="atanh", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', dtypes=[torch.cfloat], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.bitwise_not", + torch_opinfo_name="bitwise_not", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.ceil", + torch_opinfo_name="ceil", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.item", + torch_opinfo_name="item", + skips=( + # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), + # ValueError: Can't convert a tensor with 10 elements to a number! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj_physical", + torch_opinfo_name="conj_physical", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cos", + torch_opinfo_name="cos", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.cosh", + torch_opinfo_name="cosh", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', + dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) + # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cuda', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.digamma", + torch_opinfo_name="digamma", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erf", + torch_opinfo_name="erf", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfinv", + torch_opinfo_name="erfinv", + decorators=(precisionOverride({torch.float16: 1e-2, + torch.bfloat16: 1e-2, + torch.float32: 1e-4}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.erfc", + torch_opinfo_name="erfc", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp", + torch_opinfo_name="exp", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.expm1", + torch_opinfo_name="expm1", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.exp2", + torch_opinfo_name="exp2", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/issues/48010 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.fill", + torch_opinfo_name="fill", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.floor", + torch_opinfo_name="floor", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frexp", + torch_opinfo_name="frexp", + # Skipped due to numerical failures on Windows CI. + # This is also skipped in frexp earlier in the file. + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.frac", + torch_opinfo_name="frac", + skips=( + DecorateInfo( + unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.imag", + torch_opinfo_name="imag", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isfinite", + torch_opinfo_name="isfinite", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isinf", + torch_opinfo_name="isinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isposinf", + torch_opinfo_name="isposinf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isneginf", + torch_opinfo_name="isneginf", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isnan", + torch_opinfo_name="isnan", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.isreal", + torch_opinfo_name="isreal", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.i0", + torch_opinfo_name="i0", + decorators=(precisionOverride({torch.bfloat16: 3e-1, + torch.float16: 5e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.int8,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.lgamma", + torch_opinfo_name="lgamma", + decorators=(precisionOverride({torch.float16: 7e-1}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_1", + skips=skips_mvlgamma(), + decorators=( + DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs', + 'test_reference_numerics_large'), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_3", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_5", + skips=skips_mvlgamma(), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log", + torch_opinfo_name="log", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log1p", + torch_opinfo_name="log1p", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log10", + torch_opinfo_name="log10", + decorators=(precisionOverride({torch.bfloat16: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.log2", + torch_opinfo_name="log2", + decorators=(precisionOverride({torch.bfloat16: 1e-1}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + PythonRefInfo( + "_refs.logsumexp", + torch_opinfo_name="logsumexp", + # When keepdim=False logsumexp function uses squeeze operation + # that is not yet exposed in nvFuser's Python API. + ), + PythonRefInfo( + "_refs.log_softmax", + torch_opinfo_name="log_softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nan_to_num", + torch_opinfo_name="nan_to_num", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.neg", + torch_opinfo_name="neg", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.positive", + torch_opinfo_name="positive", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.real", + torch_opinfo_name="real", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.reciprocal", + torch_opinfo_name="reciprocal", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/45690 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.round", + torch_opinfo_name="round", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + skips=( + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_extremal", + device_type="cuda"), + DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + "TestUnaryUfuncs", "test_reference_numerics_normal", + device_type="cuda"), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rsqrt", + torch_opinfo_name="rsqrt", + decorators=(precisionOverride({torch.half: 5e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) + # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.chalf,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sigmoid", + torch_opinfo_name="sigmoid", + aliases=('_refs.special.expit',), + # Reference: https://github.com/pytorch/pytorch/issues/56012 + handles_complex_extremal_values=False, + handles_large_floats=False, + decorators=(precisionOverride({torch.float16: 1e-2, + torch.complex64: 1e-1, + torch.bfloat16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.complex64, torch.cdouble], device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda') + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sign", + torch_opinfo_name="sign", + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sgn", + torch_opinfo_name="sgn", + # This is an issue with the vectorised abs on CPU + handles_complex_extremal_values=False, + handles_large_floats=False, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/41245 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=[torch.bfloat16, torch.float16, torch.float32, + torch.float64]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.signbit", + torch_opinfo_name="signbit", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sin", + torch_opinfo_name="sin", + decorators=(precisionOverride({torch.bfloat16: 1e-2}),), + skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', + active_if=IS_WINDOWS), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinc", + torch_opinfo_name="sinc", + decorators=(precisionOverride({torch.bfloat16: 1e-2, + torch.float16: 1e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/49133 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_small', + dtypes=[torch.cfloat]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sinh", + torch_opinfo_name="sinh", + decorators=(precisionOverride({torch.float16: 1e-2}),), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.cdouble,)), + # Reference: https://github.com/pytorch/pytorch/issues/48641 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.int8]), + ), + ), + PythonRefInfo( + "_refs.softmax", + torch_opinfo_name="softmax", + torch_opinfo_variant_name="with_dtype", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.sqrt", + torch_opinfo_name="sqrt", + decorators=( + precisionOverride({torch.bfloat16: 7e-2}), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), + ), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=(torch.bfloat16,)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.square", + torch_opinfo_name="square", + decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), + skips=( + # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)), + # Reference: https://github.com/pytorch/pytorch/issues/52549 + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tan", + torch_opinfo_name="tan", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.tanh", + torch_opinfo_name="tanh", + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.trunc", + torch_opinfo_name="trunc", + # Fails on int32 + # https://github.com/pytorch/pytorch/issues/85258 + ), + PythonRefInfo( + "_refs.special.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.special.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + # + # Elementwise Unary Special OpInfos + # + ElementwiseUnaryPythonRefInfo( + "_refs.special.logit", + torch_opinfo_name="logit", + ), + # + # Elementwise Unary nn.functional OpInfos + # + PythonRefInfo( + "_refs.nn.functional.alpha_dropout", + torch_opinfo_name="nn.functional.alpha_dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.celu", + torch_opinfo_name="nn.functional.celu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.channel_shuffle", + torch_opinfo_name="nn.functional.channel_shuffle", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.threshold", + torch_opinfo_name="nn.functional.threshold", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.dropout", + torch_opinfo_name="nn.functional.dropout", + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # dropout is not comparable + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.elu", + torch_opinfo_name="nn.functional.elu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-03, rtol=1.2e-03), + torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardtanh", + torch_opinfo_name="nn.functional.hardtanh", + supports_out=True, + ), + PythonRefInfo( # TODO: Port this to an UnaryOpInfo + "_refs.nn.functional.gelu", + torch_opinfo_name="nn.functional.gelu", + ), + PythonRefInfo( + "_refs.nn.functional.layer_norm", + torch_opinfo_name="nn.functional.layer_norm", + skips=( + # Reference result was farther (3.5762786809723224e-07) from the precise computation + # than the torch result was (2.5068410824946596e-07)! + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float32,), device_type='cpu'), + ), + ), + PythonRefInfo( + "_refs.nn.functional.glu", + torch_opinfo_name="nn.functional.glu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pairwise_distance", + torch_opinfo_name="nn.functional.pairwise_distance", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.pdist", + torch_opinfo_name="nn.functional.pdist", + supports_out=True, + skips=( + # RunTimeError: no _refs support for torch.Tensor.index_select + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Reference result was farther (1.946091651916504e-05) from the precise + # computation than the torch result was (1.1920928955078125e-06)! + DecorateInfo( + unittest.expectedFailure, + 'TestCommon', + 'test_python_ref_torch_fallback', + dtypes=(torch.float32,), + device_type='cpu', + ), + )), + PythonRefInfo( + "_refs.nn.functional.leaky_relu", + torch_opinfo_name="nn.functional.leaky_relu", + supports_out=True, + ), + PythonRefInfo( + "_refs.nn.functional.log_softmax", + torch_opinfo_name="log_softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.pixel_shuffle", + torch_opinfo_name="nn.functional.pixel_shuffle", + ), + PythonRefInfo( + "_refs.nn.functional.pixel_unshuffle", + torch_opinfo_name="nn.functional.pixel_unshuffle", + ), + PythonRefInfo( + "_refs.nn.functional.poisson_nll_loss", + torch_opinfo_name="nn.functional.poisson_nll_loss", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.prelu", + torch_opinfo_name="nn.functional.prelu", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu", + torch_opinfo_name="nn.functional.relu", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.relu6", + torch_opinfo_name="nn.functional.relu6", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.mish", + torch_opinfo_name="nn.functional.mish", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), + 'TestUnaryUfuncs',), ], + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.selu", + torch_opinfo_name="nn.functional.selu", + supports_out=True, + decorators=[ + DecorateInfo( + toleranceOverride({ + torch.float16: tol(atol=1e-2, rtol=1.8e-2), + torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) + }), + 'TestUnaryUfuncs', device_type='cuda', + ), ], + ), + PythonRefInfo( + "_refs.nn.functional.softmax", + torch_opinfo_name="softmax", # alias + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + PythonRefInfo( + "_refs.nn.functional.softmin", + torch_opinfo_name="nn.functional.softmin", + torch_opinfo_variant_name="with_dtype", + supports_out=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softplus", + torch_opinfo_name="nn.functional.softplus", + ), + PythonRefInfo( + "_refs.nn.functional.l1_loss", + torch_opinfo_name="nn.functional.l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.margin_ranking_loss", + torch_opinfo_name="nn.functional.margin_ranking_loss", + ), + PythonRefInfo( + "_refs.nn.functional.mse_loss", + torch_opinfo_name="nn.functional.mse_loss", + ), + PythonRefInfo( + "_refs.nn.functional.smooth_l1_loss", + torch_opinfo_name="nn.functional.smooth_l1_loss", + ), + PythonRefInfo( + "_refs.nn.functional.hinge_embedding_loss", + torch_opinfo_name="nn.functional.hinge_embedding_loss" + ), + PythonRefInfo( + "_refs.nn.functional.nll_loss", + torch_opinfo_name="nn.functional.nll_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + # For simpler indexing, we flatten target indices, then reshape the result tensor. + # This creates inconsistent view state with reference impl. + validate_view_consistency=False, + skips=( + # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda" + ), + ), + ), + PythonRefInfo( + "_refs.nn.functional.huber_loss", + torch_opinfo_name="nn.functional.huber_loss", + # The corresponding PyTorch op doesn't support out. But the ref is + # registered as a decomp and ATen has an out variant. + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.tanhshrink", + torch_opinfo_name="nn.functional.tanhshrink", + decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo( + toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02), + torch.complex64: tol(atol=6e-04, rtol=1e-05)}), + 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), + ], + skips=( + # in each case, pytorch will produce a nan while numpy will not + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_large", + dtypes=(torch.complex64, torch.complex128), + active_if=(IS_MACOS)), + DecorateInfo(unittest.skip("Fails on some jobs works on others!"), + 'TestUnaryUfuncs', "test_reference_numerics_extremal", + dtypes=(torch.complex64, torch.complex128), + device_type='cpu', + active_if=(IS_MACOS or IS_WINDOWS)), + ), + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.hardshrink", + torch_opinfo_name="nn.functional.hardshrink", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.nn.functional.softshrink", + torch_opinfo_name="nn.functional.softshrink", + ), + # + # Elementwise Binary Reference OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.add", + torch_opinfo_name="add", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.atan2", + torch_opinfo_name="atan2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_and", + torch_opinfo_name="bitwise_and", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_left_shift", + torch_opinfo_name="bitwise_left_shift", + skips=( + # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_right_shift", + torch_opinfo_name="bitwise_right_shift", + skips=( + # # https://github.com/pytorch/pytorch/issues/70904 + DecorateInfo(unittest.skip("Skipped some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_or", + torch_opinfo_name="bitwise_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.bitwise_xor", + torch_opinfo_name="bitwise_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.copysign", + torch_opinfo_name="copysign", + skips=( + # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu! + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="no_rounding_mode", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # NotImplementedError: argument of type: + DecorateInfo( + unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128,) + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="trunc_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.div", + torch_opinfo_name="div", + torch_opinfo_variant_name="floor_rounding", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + # See https://github.com/pytorch/pytorch/issues/111126 + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Reference result was farther (nan) from the precise computation than the + # torch result was (inf)! + DecorateInfo( + unittest.expectedFailure, + "TestCommon", + "test_python_ref", + dtypes=(torch.bfloat16,), + device_type="cpu", + active_if=not IS_S390X, + ), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.eq", + torch_opinfo_name="eq", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.float_power", + torch_opinfo_name="float_power", + skips=( + # Test doesn't account for float -> double type promotion + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + # Complex values error with: Greatest absolute difference: nan at index + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.complex64, torch.complex128]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=[torch.complex64, torch.complex128]), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logaddexp", + torch_opinfo_name="logaddexp", + skips=( + # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + ), + ), + PythonRefInfo( + "_refs.logaddexp2", + torch_opinfo_name="logaddexp2", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.floor_divide", + torch_opinfo_name="floor_divide", + rhs_make_tensor_kwargs=dict(exclude_zero=True), + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + # bfloat16 floor_divide compared with a float32 reference works inconsistently + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,)), + # bfloat16 floor_divide compared with a float32 reference works inconsistently + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + dtypes=(torch.bfloat16,)), + # int8 floor divide has different results for -128 // -1 vs. NumPy + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + # The following tests fails on some jobs + DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.float16,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + # FIXME output 0: meta disagrees with real impl + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmax", + torch_opinfo_name="fmax", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmin", + torch_opinfo_name="fmin", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.fmod", + torch_opinfo_name="fmod", + rhs_make_tensor_kwargs={'exclude_zero': True}, + supports_rhs_python_scalar=True, + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_contig_vs_every_other', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_non_contig', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gcd", + torch_opinfo_name="gcd", + skips=( + DecorateInfo(unittest.expectedFailure, + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.int8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ge", + torch_opinfo_name="ge", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.gt", + torch_opinfo_name="gt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.heaviside", + torch_opinfo_name="heaviside", + supports_rhs_python_scalar=False, + skips=( + # PyTorch's heaviside does not appear to propagate NaNs + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.hypot", + torch_opinfo_name="hypot", + supports_rhs_python_scalar=False, + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igamma", + torch_opinfo_name="igamma", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.igammac", + torch_opinfo_name="igammac", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.isclose", + torch_opinfo_name="isclose", + skips=( + # Intentional xfail -- isclose does not type promote + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lcm", + torch_opinfo_name="lcm", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.le", + torch_opinfo_name="le", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_and", + torch_opinfo_name="logical_and", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.logical_not", + torch_opinfo_name="logical_not", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_or", + torch_opinfo_name="logical_or", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.logical_xor", + torch_opinfo_name="logical_xor", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.lt", + torch_opinfo_name="lt", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.maximum", + torch_opinfo_name="maximum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.minimum", + torch_opinfo_name="minimum", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.mul", + torch_opinfo_name="mul", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type='cuda' + ), + # Reference result was farther (0.0) from the precise computation + # than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type='cuda' + ), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.ne", + torch_opinfo_name="ne", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.nextafter", + torch_opinfo_name="nextafter", + ), + ElementwiseBinaryPythonRefInfo( + "_refs.pow", + torch_opinfo_name="pow", + decorators=( + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), + torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), + 'TestBinaryUfuncs', 'test_scalar_support'), + ), + skips=( + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (inf) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Skipping integers because they are being raised to negative powers causing an error + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=[torch.int16, torch.int32, torch.int64]), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.complex32,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_large_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_extremal_values', + dtypes=(torch.complex32, torch.complex64, torch.complex128)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.remainder", + torch_opinfo_name="remainder", + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.bfloat16,), device_type='cpu'), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.bfloat16,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.rsub", + torch_opinfo_name="rsub", + # https://github.com/pytorch/pytorch/issues/76944 + skips=( + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.chalf,), device_type='cpu'), + # Reference result was farther (nan) from the precise computation than + # the torch result was (nan)! + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.chalf,), device_type='cpu'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.sub", + torch_opinfo_name="sub", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + decorators=( + DecorateInfo( + toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), + torch.bfloat16: tol(atol=1e-5, rtol=5e-3), + torch.complex32: tol(atol=1e-5, rtol=1e-3)}), + 'TestBinaryUfuncs', 'test_reference_numerics'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), + 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_comprehensive', device_type='cpu'), + DecorateInfo( + toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), + 'TestDecomp', 'test_quick', device_type='cpu'), + ), + skips=( + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics', + dtypes=(torch.uint8,)), + DecorateInfo(unittest.skip("Skipped!"), + 'TestBinaryUfuncs', + 'test_reference_numerics_small_values', + dtypes=(torch.uint8,)), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.true_divide", + torch_opinfo_name="true_divide", + # https://github.com/pytorch/pytorch/issues/76944 + supports_two_python_scalars=True, + supports_one_python_scalar=True, + skips=( + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', + dtypes=(torch.complex32,), + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', + dtypes=(torch.complex32,), device_type="cuda" + ), + # Reference result was farther (0.7433461727239705) from the precise + # computation than the torch result was (nan)! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.complex32,), device_type="cuda" + ), + ), + ), + # + # Elementwise Ternary Reference OpInfos + # + PythonRefInfo( + "_refs.addcdiv", + torch_opinfo_name="addcdiv", + ), + PythonRefInfo( + "_refs.addcmul", + torch_opinfo_name="addcmul", + skips=( + # Reference result was farther (1.3343989849090576e-05) + # from the precise computation than the torch result + # was (9.592622518539429e-06)! + # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.float16,), device_type="cpu"), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', + dtypes=(torch.float16,), device_type="cpu"), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_min", + torch_opinfo_name="clamp_min", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + ElementwiseBinaryPythonRefInfo( + "_refs.clamp_max", + torch_opinfo_name="clamp_max", + skips=( + # test error disabled since rhs non-tensor python scalar is supported + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.clamp", + torch_opinfo_name="clamp", + ), + PythonRefInfo( + "_refs.nn.functional.triplet_margin_loss", + torch_opinfo_name="nn.functional.triplet_margin_loss", + supports_out=False, + # TODO: Uses minimum and clamp + skips=( + # AssertionError: Tensor-likes are not close! + # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed) + # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed) + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', + dtypes=(torch.uint8,), device_type="cpu"), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs.xlogy", + torch_opinfo_name="xlogy", + supports_one_python_scalar=True, + ), + # + # Elementwise Binary Special OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.special.xlog1py", + torch_opinfo_name="special.xlog1py", + supports_one_python_scalar=True, + ), + # + # Data Conversion & Data Movement Opinfos + # + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bfloat16", + torch_opinfo_name="bfloat16", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.bool", + torch_opinfo_name="bool", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.byte", + torch_opinfo_name="byte", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.char", + torch_opinfo_name="char", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.complex", + torch_opinfo_name="complex", + error_inputs_func=partial(error_inputs_complex, is_ref=True), + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.polar", + torch_opinfo_name="polar", + skips=( + # Tests don't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.double", + torch_opinfo_name="double", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.float", + torch_opinfo_name="float", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.half", + torch_opinfo_name="half", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.int", + torch_opinfo_name="int", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.long", + torch_opinfo_name="long", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.short", + torch_opinfo_name="short", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + skips=( + DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), + ) + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.chalf", + torch_opinfo_name="chalf", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cfloat", + torch_opinfo_name="cfloat", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs._conversions.cdouble", + torch_opinfo_name="cdouble", + # TODO: If self already has the correct dtype and device, then self is + # returned ignoring memory_format. + # https://github.com/pytorch/pytorch/issues/86558 + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.clone", + torch_opinfo_name="clone", + ), + # + # View & Shape OpInfos + # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.atleast_1d", + torch_opinfo_name="atleast_1d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_2d", + torch_opinfo_name="atleast_2d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.atleast_3d", + torch_opinfo_name="atleast_3d", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.as_strided_copy", + torch_opinfo_name="as_strided_copy", + supports_out=True, + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.as_strided", + torch_opinfo_name="as_strided", + torch_opinfo_variant_name="partial_views", + # FIXME: doesn't support chalf + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.as_strided_scatter", + torch_opinfo_name="as_strided_scatter", + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.block_diag", + torch_opinfo_name="block_diag", + ), + PythonRefInfo( + "_refs.broadcast_shapes", + torch_opinfo_name="broadcast_shapes", + ), + PythonRefInfo( + "_refs.broadcast_tensors", + torch_opinfo_name="broadcast_tensors", + ), + PythonRefInfo( + "_refs.broadcast_to", + torch_opinfo_name="broadcast_to", + ), + PythonRefInfo( + "_refs.cat", + torch_opinfo_name="cat", + skips=( + # FIXME: AssertionError: RuntimeError not raised + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.chunk", + torch_opinfo_name="chunk", + ), + PythonRefInfo( + "_refs.column_stack", + torch_opinfo_name="column_stack", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.conj", + torch_opinfo_name="conj", + ), + PythonRefInfo( + "_refs.constant_pad_nd", + torch_opinfo_name="constant_pad_nd", + ), + PythonRefInfo( + "_refs.contiguous", + torch_opinfo_name="contiguous", + ), + ElementwiseUnaryPythonRefInfo( + "_refs.deg2rad", + torch_opinfo_name="deg2rad", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.dsplit", + torch_opinfo_name="dsplit", + ), + PythonRefInfo( + "_refs.diag", + torch_opinfo_name="diag", + ), + PythonRefInfo( + "_refs.diagonal", + torch_opinfo_name="diagonal", + ), + PythonRefInfo( + "_refs.diagonal_copy", + torch_opinfo_name="diagonal_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.diagonal_scatter", + torch_opinfo_name="diagonal_scatter", + supports_out=True, + # returns a view of an intermediate tensor (as_strided) + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.diag_embed", + torch_opinfo_name="diag_embed", + supports_out=True, + ), + PythonRefInfo( + "_refs.dstack", + torch_opinfo_name="dstack", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.expand", + torch_opinfo_name="expand", + ), + PythonRefInfo( + "_refs.expand_as", + torch_opinfo_name="expand_as", + ), + PythonRefInfo( + "_refs.expand_copy", + torch_opinfo_name="expand_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.flatten", + torch_opinfo_name="flatten", + ), + PythonRefInfo( + "_refs.flip", + torch_opinfo_name="flip", + ), + PythonRefInfo( + "_refs.fliplr", + torch_opinfo_name="fliplr", + ), + PythonRefInfo( + "_refs.flipud", + torch_opinfo_name="flipud", + ), + PythonRefInfo( + "_refs.hstack", + torch_opinfo_name="hstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.narrow", + torch_opinfo_name="narrow", + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), + ), + PythonRefInfo( + "_refs.narrow_copy", + torch_opinfo_name="narrow_copy", + supports_out=True, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + skips=( + # The view function this decompose into does not have a ref + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), + ), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.native_layer_norm", + torch_opinfo_name="native_layer_norm", + skips=( + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref", + device_type="cpu", dtypes=(torch.float32,)), + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback", + device_type="cpu", dtypes=(torch.float32,)), + ), + ), + PythonRefInfo( + "_refs.permute", + torch_opinfo_name="permute", + ), + PythonRefInfo( + "_refs.permute_copy", + torch_opinfo_name="permute_copy", + supports_out=True, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.rad2deg", + torch_opinfo_name="rad2deg", + decorators=(precisionOverride({torch.bfloat16: 7e-1, + torch.float16: 7e-1}),), + ), + PythonRefInfo( + "_refs.ravel", + torch_opinfo_name="ravel", + ), + PythonRefInfo( + "_refs.renorm", + torch_opinfo_name="renorm", + ), + PythonRefInfo( + "_refs.repeat", + torch_opinfo_name="repeat", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.reshape", + torch_opinfo_name="reshape", + ), + PythonRefInfo( + "_refs.reshape_as", + torch_opinfo_name="reshape_as", + ), + PythonRefInfo( + "_refs.roll", + torch_opinfo_name="roll", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.rot90", + torch_opinfo_name="rot90", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.select_scatter", + torch_opinfo_name="select_scatter", + ), + PythonRefInfo( + "_refs.stack", + torch_opinfo_name="stack", + validate_view_consistency=False, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + ), + PythonRefInfo( + "_refs.squeeze_copy", + torch_opinfo_name="squeeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.squeeze", + torch_opinfo_name="squeeze", + torch_opinfo_variant_name="multiple", + ), + PythonRefInfo( + "_refs.tensor_split", + torch_opinfo_name="tensor_split", + skips=( + # RuntimeError: no _refs support for torch.Tensor.tolist + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.hsplit", + torch_opinfo_name="hsplit", + ), + PythonRefInfo( + "_refs.vsplit", + torch_opinfo_name="vsplit", + ), + PythonRefInfo( + "_refs.dot", + torch_opinfo_name="dot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.vdot", + torch_opinfo_name="vdot", + error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), + # .conj() does not set ._is_view() correctly in ATen + validate_view_consistency=False, + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), + ), + ), + PythonRefInfo( + "_refs.transpose", + torch_opinfo_name="transpose", + ), + PythonRefInfo( + "_refs.transpose_copy", + torch_opinfo_name="transpose_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.t", + torch_opinfo_name="t", + ), + PythonRefInfo( + "_refs.t_copy", + torch_opinfo_name="t_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.T", + torch_opinfo_name="T", + error_inputs_func=partial(error_inputs_T, has_ndims_error=True), + ), + PythonRefInfo( + "_refs.unbind_copy", + torch_opinfo_name="unbind_copy", + ), + PythonRefInfo( + "_refs.unfold", + torch_opinfo_name="unfold", + ), + PythonRefInfo( + "_refs.unfold_copy", + torch_opinfo_name="unfold_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.unsqueeze", + torch_opinfo_name="unsqueeze", + ), + PythonRefInfo( + "_refs.unsqueeze_copy", + torch_opinfo_name="unsqueeze_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.view", + torch_opinfo_name="view", + ), + PythonRefInfo( + "_refs.view_as", + torch_opinfo_name="view_as", + ), + PythonRefInfo( + "_refs.view_copy", + torch_opinfo_name="view_copy", + supports_out=True, + ), + PythonRefInfo( + "_refs.vstack", + torch_opinfo_name="vstack", + skips=( + # https://github.com/pytorch/pytorch/issues/78613 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.unflatten", + torch_opinfo_name="unflatten", + ), + PythonRefInfo( + "_refs.unbind", + torch_opinfo_name="unbind", + ), + # + # Reduction Reference OpInfos + # + ReductionPythonRefInfo( + "_refs.all", + torch_opinfo_name="all", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.amax", + torch_opinfo_name="amax", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.amin", + torch_opinfo_name="amin", + error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.any", + torch_opinfo_name="any", + skips=( + # FIXME: uint8 input returns uint8 instead of bool + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_result_dtype', + dtypes=[torch.uint8]), + ), + ), + ReductionPythonRefInfo( + "_refs.count_nonzero", + torch_opinfo_name="count_nonzero", + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_default_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_dim_multi_unsorted_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + ), + ), + ReductionPythonRefInfo( + "_refs.mean", + torch_opinfo_name="mean", + supports_out=True, + error_inputs_func=partial(error_inputs_mean, is_ref=True), + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionPythonRefInfo( + "_refs.std", + torch_opinfo_name="std", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=(torch.float16,)), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=(torch.float16,)), + ), + ), + # std_mean and var_mean are not ReductionInfos + PythonRefInfo( + "_refs.std_mean", + torch_opinfo_name="std_mean", + ), + ReductionPythonRefInfo( + "_refs.sum", + torch_opinfo_name="sum", + supports_out=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: mean reduces all dimensions when dim=[] + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', + 'test_ref_duplicate_values', + dtypes=[torch.float16]), + DecorateInfo( + unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', + dtypes=[torch.float32]), + ), + ), + PythonRefInfo( + "_refs.cumsum", + torch_opinfo_name="cumsum", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.cumprod", + torch_opinfo_name="cumprod", + supports_out=True, + skips=( + # doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + ), + ), + PythonRefInfo( + "_refs.sum_to_size", + torch_opinfo_name="sum_to_size", + validate_view_consistency=False, + ), + ReductionPythonRefInfo( + "_refs.prod", + torch_opinfo_name="prod", + supports_out=True, + supports_multiple_dims=True, + skips=( + # FIXME: doesn't test out behavior properly for this operator + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', + dtypes=[torch.float16, torch.complex64]), + ), + ), + ReductionPythonRefInfo( + "_refs.var", + torch_opinfo_name="var", + supports_out=True, + skips=( + # FIXME: reduces all dimensions when dim=[] + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), + DecorateInfo( + unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), + # FIXME: improve precision + DecorateInfo( + unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + ), + ), + PythonRefInfo( + "_refs.var_mean", + torch_opinfo_name="var_mean", + validate_view_consistency=False, + ), + # + # Linear Algebra Operators + # + PythonRefInfo( + "_refs.addr", + torch_opinfo_name="addr", + decorators=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), + ), + ), + PythonRefInfo( + "_refs.trace", + torch_opinfo_name="trace", + ), + PythonRefInfo( + "_refs.norm", + torch_opinfo_name="norm", + supports_out=True, + # Uses vector_norm inside and vector_norm is affected by + # https://github.com/pytorch/pytorch/issues/77216 + validate_view_consistency=False, + ), + # + # Tensor Creation Reference OpInfos + # + PythonRefInfo( + "_refs.empty", + torch_opinfo_name="empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: shouldn't check empty results + DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.empty_like", + torch_opinfo_name="empty_like", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.randn", + torch_opinfo_name="randn", + op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs), + skips=( + # see https://github.com/pytorch/pytorch/issues/85121 + DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), + 'TestCommon', + 'test_python_ref_executor'), + # These tests expect the input to be a tensor or a sequence of tensors + DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'), + ), + ), + PythonRefInfo( + "_refs.eye", + torch_opinfo_name="eye", + skips=( + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + ), + ), + PythonRefInfo( + "_refs.new_empty", + torch_opinfo_name="new_empty", + skips=( + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestCommon', + 'test_out_warning'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # FIXME: should not compare results of empty_like + DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_empty_strided", + torch_opinfo_name="new_empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + + ), + ), + PythonRefInfo( + "_refs.empty_strided", + torch_opinfo_name="empty_strided", + skips=( + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestMathBits', + 'test_neg_view'), + DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), + 'TestCommon', + 'test_python_ref_executor'), + DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + ), + ), + PythonRefInfo( + "_refs.new_full", + torch_opinfo_name="new_full", + ), + PythonRefInfo( + "_refs.new_ones", + torch_opinfo_name="new_ones", + ), + PythonRefInfo( + "_refs.new_zeros", + torch_opinfo_name="new_zeros", + ), + # + # Conditional Reference OpInfos + # + PythonRefInfo( + "_refs.masked_fill", + torch_opinfo_name="masked_fill", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.where", + torch_opinfo_name="where", + op=lambda self, condition, other: refs.where(condition, self, other), + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'), + ), + ), + PythonRefInfo( + "_refs.index_select", + torch_opinfo_name="index_select", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + # Sample out= with a stride of zero. This _out operation checks that the input has no + # inner overlap + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),) + ), + PythonRefInfo( + "_refs.index_copy", + torch_opinfo_name="index_copy", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + ), + PythonRefInfo( + "_refs.index_add", + torch_opinfo_name="index_add", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), + ), + ), + PythonRefInfo( + "_refs.index_fill", + torch_opinfo_name="index_fill", + # empty_strided + skips=( + # no _refs support for Tensor.__setitem__ + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),) + ), + # + # Test-related functions + # + PythonRefInfo( + "_refs.allclose", + torch_opinfo_name="allclose", + ), + # + # Misc functions + # + PythonRefInfo( + "_refs.stft", + torch_opinfo_name="stft", + skips=[ + # RuntimeError: no _refs support for aten.pad + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + ], + ), + PythonRefInfo( + "_refs.istft", + torch_opinfo_name="istft", + skips=[ + # RuntimeError: no _refs support for aten.unfold_backward + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref' + ), + DecorateInfo( + unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"), + 'TestCommon', + 'test_python_ref_executor', + dtypes=(torch.complex64, torch.complex128), + ), + ], + ), + PythonRefInfo( + "_refs.view_as_complex", + torch_opinfo_name="view_as_complex", + ), + PythonRefInfo( + "_refs.split_with_sizes", + torch_opinfo_name="split_with_sizes", + ), +] +python_ref_db += opinfo.definitions.python_ref_db + +# Common operator groupings +ops_and_refs = op_db + python_ref_db +unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)] +binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)] +binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)) +spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)] +sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse] +sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr] +sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse] +shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)] +reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)] +reference_filtered_ops = [op for op in reduction_ops if op.ref is not None] +reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')] +sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')] + +def index_variable(shape, max_indices, device=torch.device('cpu')): + if not isinstance(shape, tuple): + shape = (shape,) + return torch.testing.make_tensor(*shape, dtype=torch.long, device=device, low=0, high=max_indices) + +def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')): + assert len(shape) == 2 + assert index_dim < 2 + batch_dim = 1 - index_dim + index = torch.zeros(*shape, dtype=torch.long, device=device) + for i in range(shape[index_dim]): + index.select(index_dim, i).copy_( + torch.randperm(max_indices, device=device)[:shape[batch_dim]]) + if duplicate: + index.select(batch_dim, 0).copy_(index.select(batch_dim, 1)) + return index + +def bernoulli_scalar(): + return torch.tensor(0, dtype=torch.bool).bernoulli_() + +def mask_not_all_zeros(shape): + assert len(shape) > 0 + while True: + result = torch.randn(shape).gt(0) + if result.sum() > 0: + return result + +# Copied from functorch +def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + + +def skip(op_name, variant_name='', *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [o for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for op in matching_opinfos: + decorators = list(op.decorators) + if expected_failure: + decorator = DecorateInfo(unittest.expectedFailure, + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + else: + decorator = DecorateInfo(unittest.skip("Skipped!"), + test_case_name, base_test_name, + device_type=device_type, dtypes=dtypes) + decorators.append(decorator) + op.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + return wrapped diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..13cd86e05bd6f7b4e9515cf102cc1e6d3b49781d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py @@ -0,0 +1,385 @@ +# Owner(s): ["module: unknown"] + +from typing import Any +from torch.ao.pruning import BaseSparsifier +import torch +import torch.nn.functional as F +from torch import nn + +class ImplementedSparsifier(BaseSparsifier): + def __init__(self, **kwargs: dict[str, Any]) -> None: + super().__init__(defaults=kwargs) + + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None: + module.parametrizations.weight[0].mask[0] = 0 # type: ignore[index, union-attr] + linear_state = self.state['linear1.weight'] + linear_state['step_count'] = linear_state.get('step_count', 0) + 1 + + +class MockSparseLinear(nn.Linear): + """ + This class is a MockSparseLinear class to check convert functionality. + It is the same as a normal Linear layer, except with a different type, as + well as an additional from_dense method. + """ + @classmethod + def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear': + """ + """ + linear = cls(mod.in_features, + mod.out_features) + return linear + + +def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool: + """ + Checks to see if all rows in subset tensor are present in the superset tensor + """ + i = 0 + for row in subset_tensor: + while i < len(superset_tensor): + if not torch.equal(row, superset_tensor[i]): + i += 1 + else: + break + else: + return False + return True + + +class SimpleLinear(nn.Module): + r"""Model with only Linear layers without biases, some wrapped in a Sequential, + some following the Sequential. Used to test basic pruned Linear-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=False), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 4, bias=False), + ) + self.linear1 = nn.Linear(4, 4, bias=False) + self.linear2 = nn.Linear(4, 10, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.linear2(x) + return x + + +class LinearBias(nn.Module): + r"""Model with only Linear layers, alternating layers with biases, + wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.Linear(5, 6, bias=False), + nn.Linear(6, 3, bias=True), + nn.Linear(3, 3, bias=True), + nn.Linear(3, 10, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + return x + + +class LinearActivation(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.Tanh(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.act1 = nn.ReLU() + self.linear2 = nn.Linear(3, 10, bias=False) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = self.act1(x) + x = self.linear2(x) + x = self.act2(x) + return x + + +class LinearActivationFunctional(nn.Module): + r"""Model with only Linear layers, some with bias, some in a Sequential and some following. + Activation functions modules in between each Linear in the Sequential, and functional + activationals are called in between each outside layer. + Used to test pruned Linear(Bias)-Activation-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Linear(7, 5, bias=True), + nn.ReLU(), + nn.Linear(5, 6, bias=False), + nn.ReLU(), + nn.Linear(6, 4, bias=True), + ) + self.linear1 = nn.Linear(4, 3, bias=True) + self.linear2 = nn.Linear(3, 8, bias=False) + self.linear3 = nn.Linear(8, 10, bias=False) + self.act1 = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + x = F.relu(x) + x = self.linear3(x) + x = F.relu(x) + return x + + +class SimpleConv2d(nn.Module): + r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following. + Used to test pruned Conv2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=False), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dBias(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside. + Used to test pruned Conv2d-Bias-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.Conv2d(32, 32, 3, 1, bias=True), + nn.Conv2d(32, 64, 3, 1, bias=False), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + +class Conv2dActivation(nn.Module): + r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following. + Activation function modules in between each Sequential layer, functional activations called + in-between each outside layer. + Used to test pruned Conv2d-Bias-Activation-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + nn.Conv2d(64, 64, 3, 1, bias=False), + nn.ReLU(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False) + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.relu(x) + x = self.conv2d2(x) + x = F.hardtanh(x) + return x + + +class Conv2dPadBias(nn.Module): + r"""Model with only Conv2d layers, all with bias and some with padding > 0, + some in a Sequential and some following. Activation function modules in between each layer. + Used to test that bias is propagated correctly in the special case of + pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, bias=False), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 32, 3, 1, padding=1, bias=True), + nn.ReLU(), + nn.Conv2d(32, 64, 3, 1, bias=True), + nn.Tanh(), + ) + self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True) + self.act1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True) + self.act2 = nn.Tanh() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.act1(x) + x = self.conv2d2(x) + x = self.act2(x) + return x + + +class Conv2dPool(nn.Module): + r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following. + Activation function modules in between each layer, Pool2d modules in between each layer. + Used to test pruned Conv2d-Pool2d-Conv2d fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True) + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True) + self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = self.maxpool(x) + x = self.af1(x) + x = self.conv2d2(x) + x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1) + x = F.relu(x) + x = self.conv2d3(x) + return x + + +class Conv2dPoolFlattenFunctional(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a functional Flatten followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(11, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = torch.flatten(x, 1) # test functional flatten + x = self.fc(x) + return x + + +class Conv2dPoolFlatten(nn.Module): + r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d + and a Flatten module followed by a Linear layer. + Activation functions and Pool2ds in between each layer also. + Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion.""" + + def __init__(self) -> None: + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True), + nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True), + nn.Tanh(), + nn.AvgPool2d(kernel_size=2, stride=2, padding=1), + ) + self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True) + self.af1 = nn.ReLU() + self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True) + self.avg_pool = nn.AdaptiveAvgPool2d((2, 2)) + self.flatten = nn.Flatten() + self.fc = nn.Linear(44, 13, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.seq(x) + x = self.conv2d1(x) + x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1) + x = self.af1(x) + x = self.conv2d2(x) + x = self.avg_pool(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +class LSTMLinearModel(nn.Module): + """Container module with an encoder, a recurrent module, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + output, _hidden = self.lstm(input) + decoded = self.linear(output) + return decoded, output + + +class LSTMLayerNormLinearModel(nn.Module): + """Container module with an LSTM, a LayerNorm, and a linear.""" + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int + ) -> None: + super().__init__() + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers) + self.norm = nn.LayerNorm(hidden_dim) + self.linear = nn.Linear(hidden_dim, output_dim) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x, state = self.lstm(x) + x = self.norm(x) + x = self.linear(x) + return x, state diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py new file mode 100644 index 0000000000000000000000000000000000000000..773bea63eef82f7dd83034d764a484a6085ed3ab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py @@ -0,0 +1,608 @@ +# mypy: ignore-errors + +import torch +from torch import Tensor +import itertools + +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from torch.utils import _pytree as pytree +from functools import partial +from torch.utils._mode_utils import no_dispatch, all_same_mode +import torch.autograd.forward_ad as fwAD +from collections.abc import Callable +import re + + +def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor): + elem = wrapper_tensor.elem + metadata_wrapper_tensor = metadata_accessor(wrapper_tensor) + metadata_elem = metadata_accessor(elem) + if metadata_wrapper_tensor == metadata_elem: + return + raise RuntimeError( + f"This operator is not Composite Compliant: the " + f"{metadata_name} of the tensor was modified directly without " + f"going through the PyTorch dispatcher.") + +def check_metadata_consistency(wrapper_tensor, CCT): + # CCT: CompositeCompliantTensor class which is generated using generate_cct + if not isinstance(wrapper_tensor, CCT): + return + things_to_check = { + 'shape': Tensor.size, + 'dtype': lambda x: x.dtype, + 'device': lambda x: x.device, + 'numel': Tensor.numel, + 'stride': Tensor.stride, + 'storage_offset': Tensor.storage_offset, + } + for metadata_name, metadata_accessor in things_to_check.items(): + check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor) + +def is_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided', + 'detach', + 'diagonal', + 'expand', + 'expand_as', + 'movedim', + 'narrow', + 'permute', + 'select', + 'squeeze', + 'transpose', + 't', + 'real', + 'imag', + 'view_as_real', + 'view_as_complex', + 'unflatten', + 'unfold', + 'unsqueeze', + 'view', + 'view_as', + 'unbind', + 'split', + 'split_with_sizes', + 'vsplit', + 'hsplit', + 'tensor_split', + 'chunk', + 'swapaxes', + 'slice', + '_reshape_alias', + '_unsafe_view', + '_conj', + 'alias', + } + +# manually populated from native_functions that have inplace_view: True. +# In the future we will probably be able to grab that list directly +def is_inplace_view_fn(func): + return func.overloadpacket.__name__ in { + 'as_strided_', + 'detach_', + 'squeeze_', + 'swapaxes_', + 'swapdims_', + 't_', + 'transpose_', + 'unsqueeze_', + } + + +# Introspection please save us +def is_inplace(func): + name = func.overloadpacket.__name__ + if re.match('__i.+__', name): + return True + if re.match('__.+__', name): + return False + return name[-1] == '_' + + +def generate_cct_and_mode(autograd_view_consistency=True): + # This function returns a new class CompositeCompliantTensor + # The two arguments control the behaviour described below. + + # autograd_view_consistency: + # If True, alias result using `set_` if func returns a view + # (See Note [Alias Result]). + # Since Forward AD doesn't work with `set_` + # we disable it by setting alias to False. + + class CompositeCompliantTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, mode, *args, **kwargs): + assert type(elem) is not cls, \ + "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported" + + # The storage of CompositeCompliantTensor should never be used directly + # by a Composite operation; if the Composite + # operator attempts to read from the storage without dispatching then it'll + # raise a RuntimeError due to it being a meta storage. + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=elem.requires_grad, + strides=elem.stride(), storage_offset=elem.storage_offset()) + + if elem.requires_grad: + # CompositeCompliantTensor steals the "requires_grad"-ness. + # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests... + tmp = torch.empty( + (), + dtype=elem.dtype, + device=elem.device, + layout=elem.layout, + requires_grad=False, + ) + # Use set_ rather than empty_strided() + copy_ so that we can preserve + # things like storage_offset. + tmp.set_( + source=elem.untyped_storage().clone(), + storage_offset=elem.storage_offset(), + size=elem.size(), + stride=elem.stride(), + ) + r.elem = tmp + else: + r.elem = elem + + assert r.stride() == r.elem.stride() + + # Propagate conjugate bits to the wrapper tensor + # Ref: https://github.com/albanD/subclass_zoo/issues/24 + # Ref: https://github.com/albanD/subclass_zoo/issues/21 + torch._C._set_conj(r, r.elem.is_conj()) + torch._C._set_neg(r, r.elem.is_neg()) + + r.mode = mode + return r + + def __repr__(self): + return f"CompositeCompliantTensor({self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + all_args = pytree.arg_tree_leaves(*args, **(kwargs or {})) + modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor)) + if not all_same_mode(modes): + raise RuntimeError("Multiple CompositeCompliantTensorModes NYI") + with modes[0]: + return func(*args, **kwargs) + + class CompositeCompliantTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, CompositeCompliantTensor) else e + + def wrap(e): + return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e + + if func is torch.ops.aten._local_scalar_dense.default: + raise RuntimeError( + ".item() is not allowed to be called inside of composite " + "functions in the PyTorch library because not all backends " + "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.") + + if func.overloadpacket.__name__ in ('set_', 'resize_'): + raise RuntimeError( + f"{func.__name__} is not allowed to be called inside of " + f"Composite operators.") + + if is_inplace(func): + # NB: We are making an assumption that if the function is in-place, + # then the first argument is being written to. Introspection please save us! + mutated_argument = args[0] + if not isinstance(mutated_argument, CompositeCompliantTensor) and \ + any(isinstance(a, CompositeCompliantTensor) for a in args[1:]): + raise RuntimeError( + 'Not composite compliant: performing in-place operation ' + f'{func.__name__} where the Tensor being written to is ' + 'regular Tensor but the other tensors are Tensor Subclasses. ' + 'Please try to avoid this in-place operation.') + + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs) + rs = tree_map(wrap, unwrapped_rs) + + if is_view_fn(func) and autograd_view_consistency: + # Note [Alias Result] + # Autograd asserts that for B = A.view_fn(...), B and A's storages + # are the same. Here we try to make B alias A to avoid those asserts. + # See https://github.com/pytorch/pytorch/issues/65339 for more information + # about the issue. + with no_dispatch(): + # Idea: this is a weird way of getting a storage that aliases the input. + # This is a workaround for #65339. + # 1. under no_dispatch, all of the wrapper tensors look like regular + # tensors with special storage (the storage is nullptr and + # advertises CPU/CUDA device. + # 2. we run func, which ends up running the view operation + # 3. All view operations reuse the input's storage and return + # result Tensor(s) with new sizes/strides/offset that alias + # the input. + # 4. we set the storage (and sizes/strides/offset) of the wrapper + # tensor results to be that of the tensors that alias the input + result = func(*args, **kwargs) + if isinstance(result, (tuple, list)): + for a, b in zip(rs, result, strict=True): + a.set_(b) + else: + rs.set_(result) + + # Some operations are allowed to in-place modify the metadata of the + # inputs. The only ones are the "inplace view functions"; when we + # run into these, we manually modify the metadata of the input. + with no_dispatch(): + if is_inplace_view_fn(func): + func(*args, **kwargs) + + # For each CompositeCompliantTensor t, we check that t and t.elem + # have consistent metadata. If they don't have consistent metadata, + # that means the operator did something fishy. + check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor) + pytree.tree_map_(check, args) + pytree.tree_map_(check, kwargs) + pytree.tree_map_(check, rs) + return rs + + return CompositeCompliantTensor, CompositeCompliantTensorMode() + +def is_tensorlist(lst): + if not isinstance(lst, list) and not isinstance(lst, tuple): + return False + if len(lst) == 0: + return False + all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst) + if all_tensors: + return True + exists_one_tensor = all(isinstance(elt, torch.Tensor) for elt in lst) + if exists_one_tensor: + raise RuntimeError('This test assumes that PyTorch APIs cannot take ' + 'mixed lists of Tensor and other things') + return False + + +def maybe_map(fn, should_map, arg): + return fn(arg) if should_map else arg + + +def wrap(arg, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + if isinstance(arg, torch.Tensor): + return CCT(arg, cct_mode) + if is_tensorlist(arg): + return [CCT(a, cct_mode) for a in arg] + raise RuntimeError("wrap assumes that the input can be wrapped") + + +# Given a list of flat arguments, some of which may be Tensors, return all +# possible ways some of the arguments could be CompositeCompliantTensors (CCT). +# For example, given Tensors A, B, C and flat_args = [A, 1, B], +# We would return the following 4 options: +# [CCT(A), 1, CCT(B)] +# [CCT(A), 1, B] +# [A, 1, CCT(B)] +# [A, 1, B] +# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops +# don't accept that many input Tensors. +def generate_subclass_choices(flat_args, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args] + subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes] + + for which_args_are_wrapped in itertools.product(*subclass_options): + + result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg) + for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args, strict=True)] + yield result, which_args_are_wrapped + + +# For an operation f(*args, **kwargs), each Tensor argument may either be +# a regular Tensor or a Tensor Subclass. This iterator iterates through +# all of those options. +def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode + flat_kwargs, spec = tree_flatten(kwargs) + flat_args_kwargs = list(args) + list(flat_kwargs) + for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode): + new_args = choice[:len(args)] + new_kwargs = tree_unflatten(choice[len(args):], spec) + which_args_are_wrapped = debug_metadata[:len(args)] + which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec) + yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped + + +def raise_composite_compliance_error(err, additional_info=''): + raise RuntimeError( + "Composite compliance check failed with " + "the above error.\n" + f"{additional_info}" + "If you are adding an OpInfo of an " + "existing operator, please feel free to skip this test " + "because the problem was pre-existing and file an issue. " + "Otherwise, if you added a new operator, please read " + "through the Composite Compliance section in " + "aten/src/ATen/native/README.md for how to resolve this. " + ) from err + + +# This test checks ALL possible permutations of calling `op` with arguments +# that are individually either a regular Tensor or a Tensor subclass. +# +# The general strategy is to wrap some Tensor args and kwargs in +# CompositeCompliantTensor wrappers and call the operation. + +# If some composite operation does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_all_permutations(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + expected = op(*args, **kwargs) + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + try: + actual = op(*new_args, **new_kwargs) + # NOTE: [What errors are Composite Compliance trying to catch?] + # + # There's two things we want to catch: + # - errors that would raise within the torch_dispatch impl + # - data_ptr accesses + # The first is easy to filter for (we could make the error a different + # error class), the second is always going to be a RuntimeError due to + # how it is implemented (if you try to access the data_ptr of the + # wrapper Tensor, it raises you some internal RuntimeError). + # + # So the most general thing to catch here was RuntimeError. If you + # are here and debugging why your test failed, it's plausible that + # the operator itself is broken and that there are other tests failing. + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +# Checks via the usage of torch dispatch mode certain anti-patterns that +# are not composite compliant. +# +# In particular, the anti-pattern we are trying to prevent is a user +# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps +# here because all factory functions will create tensors that are +# CompositeCompliantTensor. +# +# The general strategy is to wrap all Tensor args and kwargs in +# CompositeCompliantTensor wrappers. If an operator that is +# Composite does any non-compliant behavior, +# CompositeCompliantTensor will raise an error. +def check_with_mode(op, args, kwargs, assert_equal_fn): + CCT, cct_mode = generate_cct_and_mode() + + def wrap(e): + return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e + + expected = op(*args, **kwargs) + + args = tree_map(wrap, args) + kwargs = tree_map(wrap, kwargs) + try: + with cct_mode: + actual = op(*args, **kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error(err) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tree_map(unwrap, actual), expected) + +def gather_leaf_tensors(args, kwargs): + leaf_tensors = [] + args, _args_spec = tree_flatten(args) + kwargs, _kwargs_spec = tree_flatten(kwargs) + args = args + kwargs + for arg in args: + if not isinstance(arg, torch.Tensor): + continue + if arg.requires_grad: + leaf_tensors.append(arg) + return leaf_tensors + + +def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None): + if gradcheck_wrapper is None: + results = op(*args, **kwargs) + else: + results = gradcheck_wrapper(op, *args, **kwargs) + + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results] + leaf_tensors = gather_leaf_tensors(args, kwargs) + assert len(leaf_tensors) > 0 + return torch.autograd.grad(flat_diff_results, leaf_tensors, + grads, allow_unused=True, retain_graph=True) + + +# Checks if the backward formula is composite compliant by testing +# all possible permutations of {inputs, grad_outputs} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_backward_formula to things that aren't OpInfos +# while debugging. +def check_backward_formula(op: Callable, args, kwargs, + output_process_fn_grad=None, + gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode() + + expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper) + + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + leaf_tensors = gather_leaf_tensors(new_args, new_kwargs) + assert len(leaf_tensors) > 0 + + try: + if gradcheck_wrapper is None: + results = op(*new_args, **new_kwargs) + else: + results = gradcheck_wrapper(op, *new_args, **new_kwargs) + if output_process_fn_grad is not None: + results = output_process_fn_grad(results) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + ) + + flat_results = pytree.tree_leaves(results) + flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)] + flat_diff_results = [r for r in flat_results if r.requires_grad] + assert len(flat_diff_results) > 0 + + # NB: ones, not ones_like, so we get a regular Tensor here + grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) + for r in flat_diff_results] + for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode): + try: + actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, + allow_unused=True, retain_graph=True) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_grads: {which_grad_is_batched}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True) + +# Checks if the forward AD formula is composite compliant by testing +# all possible permutations of {primals, tangents} being +# CompositeCompliantTensor or regular Tensors. +# +# NB: it is important that op is accepted as a Callable and not an OpInfo, +# this means we can apply check_forward_ad_formula to things that aren't OpInfos +# while debugging. +def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None): + CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False) + + def maybe_tangent(t): + assert type(t) is not CCT + # Generate `tangent` tensor + # if given object is a Tensor and requires grad is set. + if isinstance(t, torch.Tensor) and t.requires_grad: + return torch.randn_like(t) + elif is_tensorlist(t): + return [torch.randn_like(e) if e.requires_grad else None for e in t] + return None + + tangent_args = tuple(maybe_tangent(arg) for arg in args) + flat_kwargs, spec = tree_flatten(kwargs) + flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs) + tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec) + + with fwAD.dual_level(): + def maybe_make_dual(dual): + # Returns dual tensor if primal is a tensor/tensor subclass + # with requires_grad set. + primal, tangent = dual + if isinstance(primal, torch.Tensor) and primal.requires_grad: + return fwAD.make_dual(primal.detach(), tangent) + elif is_tensorlist(primal): + return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri + for pri, tang in zip(primal, tangent, strict=True)) + return primal + + def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): + op_args = tuple(map(maybe_make_dual, zip(args, tangent_args, strict=True))) + op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()} + + if gradcheck_wrapper is None: + return op(*op_args, **op_kwargs) + return gradcheck_wrapper(op, *op_args, **op_kwargs) + + expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) + expected = tree_map(fwAD.unpack_dual, expected) + expected_primals = tree_map( + lambda x: x.primal, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + expected_tangents = tree_map( + lambda x: x.tangent, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + + # Permutations of arg and kwargs in CCT. + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice + + # Permutations tangent arg and tangent kwargs in CCT. + for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode): + new_tang_args, new_tang_kwargs, \ + which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice + + op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args, strict=True))) + op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()} + + try: + if gradcheck_wrapper is None: + actual = op(*op_args, **op_kwargs) + else: + actual = gradcheck_wrapper(op, *op_args, **op_kwargs) + # see NOTE: [What errors are Composite Compliance trying to catch?] + except RuntimeError as err: + raise_composite_compliance_error( + err, + f"- wrapped_args: {which_args_are_wrapped}\n" + f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n" + f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n" + f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n" + ) + + def unwrap(e): + return e.elem if isinstance(e, CCT) else e + + actual = tree_map(fwAD.unpack_dual, actual) + actual_primals = tree_map( + lambda x: unwrap(x.primal), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + actual_tangents = tree_map( + lambda x: unwrap(x.tangent), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + assert_equal_fn(actual_primals, expected_primals, equal_nan=True) + assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..de1b44ba8dac890142eaf2b013c2399eb59c2193 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py @@ -0,0 +1,160 @@ +# mypy: ignore-errors + + +from collections import namedtuple + +import torch +import torch.utils._pytree as pytree +from torch.utils._python_dispatch import return_and_correct_aliasing + + +FancyNamedTuple = namedtuple("FancyNamedTuple", ["foo", "bar"]) + + +# A simple tensor subclass that holds a tensor with custom metadata and custom method +class ConstantExtraMetadataTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + shape = elem.shape + kwargs = {} + kwargs["strides"] = elem.stride() + kwargs["storage_offset"] = elem.storage_offset() + kwargs["device"] = elem.device + kwargs["layout"] = elem.layout + kwargs["requires_grad"] = elem.requires_grad + kwargs["dtype"] = elem.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem): + self.elem = elem + self.constant_attribute = 4 + + def __repr__(self): + inner_repr = repr(self.elem) + return f"CustomTensor({inner_repr})" + + def get_complicated_metadata(self): + return FancyNamedTuple(self.constant_attribute, self.constant_attribute) + + def __tensor_flatten__(self): + return ["elem"], self.constant_attribute + + def add_constant(self, a): + self.constant_attribute += a + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is not None + elem = inner_tensors["elem"] + out = ConstantExtraMetadataTensor(elem) + out.constant_attribute = meta + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, args + ) + + kwargs_inner = pytree.tree_map_only( + ConstantExtraMetadataTensor, lambda x: x.elem, kwargs + ) + + out_inner = func(*args_inner, **kwargs_inner) + out_inner_flat, spec = pytree.tree_flatten(out_inner) + # for aten ops that return non-tensors, just assume that + # our cust inner tensors return the same value + out_flat = [ + ConstantExtraMetadataTensor(o_inner) + if isinstance(o_inner, torch.Tensor) + else o_inner + for o_inner in out_inner_flat + ] + out = pytree.tree_unflatten(out_flat, spec) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# A simple tensor subclass that always returns plain tensor during __torch_dispatch__ +# It is similar to TwoTensor and is used to simulate torchao quantized tensors +class CustomTensorPlainOut(torch.Tensor): + @staticmethod + def __new__(cls, elem1, elem2): + shape = elem1.shape + kwargs = {} + kwargs["strides"] = elem1.stride() + kwargs["storage_offset"] = elem1.storage_offset() + kwargs["device"] = elem1.device + kwargs["layout"] = elem1.layout + kwargs["requires_grad"] = elem1.requires_grad + kwargs["dtype"] = elem1.dtype + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, elem1, elem2): + self.elem1 = elem1 + self.elem2 = elem2 + + def get_elem(self): + return self.elem1 + + def __repr__(self): + inner_repr_1 = repr(self.elem1) + inner_repr_2 = repr(self.elem2) + return f"CustomTensorPlainOut({inner_repr_1}, {inner_repr_2})" + + def __tensor_flatten__(self): + return ["elem1", "elem2"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + elem1 = inner_tensors["elem1"] + elem2 = inner_tensors["elem2"] + out = CustomTensorPlainOut(elem1, elem2) + return out + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # Don't use this tensor with view ops + if kwargs is None: + kwargs = {} + args_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, args + ) + + kwargs_inner_1 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem1, kwargs + ) + + args_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, args + ) + + kwargs_inner_2 = pytree.tree_map_only( + CustomTensorPlainOut, lambda x: x.elem2, kwargs + ) + + out_inner_1 = func(*args_inner_1, **kwargs_inner_1) + out_inner_2 = func(*args_inner_2, **kwargs_inner_2) + + out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1) + out_inner_flat_2, spec = pytree.tree_flatten(out_inner_2) + + if func.is_view: + new_out = pytree.tree_unflatten( + ( + CustomTensorPlainOut(tensor1, tensor2) + for tensor1, tensor2 in zip( + out_inner_flat_1, out_inner_flat_2, strict=True + ) + ), + spec, + ) + return return_and_correct_aliasing(func, args, kwargs, new_out) + + out_new = ( + out_inner_flat_1[ix] + out_inner_flat_2[ix] + for ix in range(len(out_inner_flat_1)) + ) + + return pytree.tree_unflatten(out_new, spec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py @@ -0,0 +1 @@ +# mypy: ignore-errors diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py new file mode 100644 index 0000000000000000000000000000000000000000..8755643a78cca80668988df9e9db3de75778b5db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py new file mode 100644 index 0000000000000000000000000000000000000000..19b0b8ee53d3b530aa33978c7a13da4e5fee4ebd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch.nn as nn + + +class Net(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(10, 20) + self.relu = nn.ReLU() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45af2552cf25cef03a517f5b136c1a2e61c3a61d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py @@ -0,0 +1,199 @@ +# mypy: ignore-errors + +import re +import sys +import time +from functools import partial, wraps + +import torch.distributed as dist +import torch.distributed.rpc as rpc +from torch.distributed.rpc import _rref_context_get_debug_info +from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN + + +if not dist.is_available(): + print("c10d not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}" + +def dist_init( + old_test_method=None, + setup_rpc: bool = True, + clean_shutdown: bool = True, + faulty_messages=None, + messages_to_delay=None, +): + """ + We use this decorator for setting up and tearing down state since + MultiProcessTestCase runs each `test*` method in a separate process and + each process just runs the `test*` method without actually calling + 'setUp' and 'tearDown' methods of unittest. + + Note: pass the string representation of MessageTypes that should be used + with the faulty agent's send function. By default, all retriable messages + ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE", + "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is + set from faulty_rpc_agent_test_fixture.py). + """ + # If we use dist_init without arguments (ex: @dist_init), old_test_method is + # appropriately set and we return the wrapper appropriately. On the other + # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)), + # old_test_method is None and we return a functools.partial which is the real + # decorator that is used and as a result we recursively call dist_init with + # old_test_method and the rest of the arguments appropriately set. + if old_test_method is None: + return partial( + dist_init, + setup_rpc=setup_rpc, + clean_shutdown=clean_shutdown, + faulty_messages=faulty_messages, + messages_to_delay=messages_to_delay, + ) + + @wraps(old_test_method) + def new_test_method(self, *arg, **kwargs): + # Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted + # in tests. + import torch.distributed.rpc.api as api + + api._ignore_rref_leak = False + self.worker_id = self.rank + self.setup_fault_injection(faulty_messages, messages_to_delay) + + rpc_backend_options = self.rpc_backend_options + if setup_rpc: + if TEST_WITH_TSAN: + # TSAN runs much slower. + rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5 + rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60 + + rpc.init_rpc( + name=f"worker{self.rank:d}", + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=rpc_backend_options, + ) + + return_value = old_test_method(self, *arg, **kwargs) + + if setup_rpc: + rpc.shutdown(graceful=clean_shutdown) + + return return_value + + return new_test_method + + +def noop() -> None: + pass + + +def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str: + """ + Loops until an RPC to the given rank fails. This is used to + indicate that the node has failed in unit tests. + Args: + rank (int): Rank of the node expected to fail + expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure + occurs, not just any. + """ + while True: + try: + rpc.rpc_sync(f"worker{rank}", noop, args=()) + time.sleep(0.1) + except Exception as e: + if re.search(pattern=expected_error_regex, string=str(e)): + return str(e) + + +def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None: + """ + The RRef protocol holds forkIds of rrefs in a map until those forks are + confirmed by the owner. The message confirming the fork may arrive after + our tests check whether this map is empty, which leads to failures and + flaky tests. to_here also does not guarantee that we have finished + processind the owner's confirmation message for the RRef. This function + loops until the map is empty, which means the messages have been received + as processed. Call this function before asserting the map returned by + _get_debug_info is empty. + """ + start = time.time() + while True: + debug_info = _rref_context_get_debug_info() + num_pending_futures = int(debug_info["num_pending_futures"]) + num_pending_users = int(debug_info["num_pending_users"]) + if num_pending_futures == 0 and num_pending_users == 0: + break + time.sleep(0.1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting to flush pending futures and users, " + f"had {num_pending_futures} pending futures and {num_pending_users} pending users" + ) + + +def get_num_owners_and_forks() -> tuple[str, str]: + """ + Retrieves number of OwnerRRefs and forks on this node from + _rref_context_get_debug_info. + """ + rref_dbg_info = _rref_context_get_debug_info() + num_owners = rref_dbg_info["num_owner_rrefs"] + num_forks = rref_dbg_info["num_forks"] + return num_owners, num_forks + + +def wait_until_owners_and_forks_on_rank( + num_owners: int, num_forks: int, rank: int, timeout: int = 20 +) -> None: + """ + Waits until timeout for num_forks and num_owners to exist on the rank. Used + to ensure proper deletion of RRefs in tests. + """ + start = time.time() + while True: + num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync( + worker_name(rank), get_num_owners_and_forks, args=(), timeout=5 + ) + num_owners_on_rank = int(num_owners_on_rank) + num_forks_on_rank = int(num_forks_on_rank) + if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks: + return + time.sleep(1) + if time.time() - start > timeout: + raise ValueError( + f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank," + f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks" + ) + + +def initialize_pg(init_method, rank: int, world_size: int) -> None: + # This is for tests using `dist.barrier`. + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + + +def worker_name(rank: int) -> str: + return f"worker{rank}" + + +def get_function_event(function_events, partial_event_name): + """ + Returns the first event that matches partial_event_name in the provided + function_events. These function_events should be the output of + torch.autograd.profiler.function_events(). + + Args: + function_events: function_events returned by the profiler. + event_name (str): partial key that the event was profiled with. + """ + event = [event for event in function_events if partial_event_name in event.name][0] # noqa: RUF015 + return event diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1e93c41de72a765415a8dce5d3c98c8cd0cf2c41 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py @@ -0,0 +1,44 @@ +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +e_bool = True +e_int = 1 +e_float = 1.0 +e_string = "string" +e_list = [1] +e_set = {1} +e_tuple = (1,) +e_dict = {1: 2} +e_none: Optional[bool] = None +e_optional: Optional[bool] = True +e_ignored = True +_e_ignored = True +magic_cache_config_ignored = True +# [@compile_ignored: debug] +e_compile_ignored = True +e_config: bool = Config(default=True) +e_jk: bool = Config(justknob="does_not_exist", default=True) +e_jk_false: bool = Config(justknob="does_not_exist", default=False) +e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False) +e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True) +e_env_default_str: bool = Config(env_name_default="ENV_STR", default="default") +e_env_default_str_empty: bool = Config( + env_name_default="ENV_STR_EMPTY", default="default" +) +e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False) +e_aliased_bool: bool = Config( + alias="torch.testing._internal.fake_config_module2.e_aliasing_bool" +) + + +class nested: + e_bool = True + + +_cache_config_ignore_prefix = ["magic_cache_config"] +_save_config_ignore = ["e_ignored"] + +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py new file mode 100644 index 0000000000000000000000000000000000000000..77c2e2baa4ddca7685adf734809488979c21ab63 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py @@ -0,0 +1,13 @@ +import sys + +from torch.utils._config_module import Config, install_config_module + + +e_aliasing_bool = False + +e_env_default_multi: bool = Config( + env_name_default=["ENV_TRUE", "ENV_FALSE"], default=False +) +e_env_force_multi: bool = Config(env_name_force=["ENV_FAKE", "ENV_TRUE"], default=False) + +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..2c8fdd3bb138fa7225a10d7518f148784a7bb116 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py @@ -0,0 +1,2905 @@ +""" +This file is needed for generating procedural tests required for +testing __torch_function__. See tests/test_overrides.py. +""" + +# flake8: noqa +import torch + +annotated_args = { + torch._C._VariableFunctions._cast_Byte: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Char: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Double: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Float: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Int: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Long: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Short: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cast_Half: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dual: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._unpack_dual: [{'is_kwarg_only': 'False', 'name': 'dual', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.align_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._functional_assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._assert_tensor_metadata: [{'is_kwarg_only': 'False', 'name': 'a', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._print: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._functional_sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._functional_sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_dep_token: [], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}], + torch._C._VariableFunctions._use_cudnn_rnn_flatten_weight: [], + torch._C._VariableFunctions._cudnn_rnn_flatten_weight: [{'is_kwarg_only': 'False', 'name': 'weight_arr', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cudnn_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight_buf', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions._cudnn_init_dropout_state: [{'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout_seed', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._debug_has_internal_overlap: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_dropout: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions._masked_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool?'}], + torch._C._VariableFunctions._sobol_engine_draw: [{'is_kwarg_only': 'False', 'name': 'quasi', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sobol_engine_ff_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_scramble_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ltm', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._sobol_engine_initialize_state_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._reshape_from_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._shape_as_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.feature_alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.imag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.adaptive_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_check_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_functorch_fallback: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._dim_arange: [{'is_kwarg_only': 'False', 'name': 'like', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantized_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._batch_norm_impl_index: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch._C._VariableFunctions.bilinear: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binary_cross_entropy_with_logits: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.broadcast_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.block_diag: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_is_acceptable: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.constant_pad_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions._convolution_mode: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_tbc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._copy_from_and_resize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cudnn_affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'H', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'W', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions.cudnn_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}], + torch._C._VariableFunctions._mps_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.cudnn_grid_sampler: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummax_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._cummin_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.einsum: [{'is_kwarg_only': 'False', 'name': 'equation', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max_norm', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_type', 'simple_type': 'double'}], + torch._C._VariableFunctions._embedding_bag_forward_only: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._rowwise_prune: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'compressed_indices_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_grad_by_freq', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sparse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'per_sample_weights', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'include_last_offset', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_idx', 'simple_type': 'int64_t?'}], + torch._C._VariableFunctions._embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.empty_permuted: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'physical_layout', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._empty_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._empty_per_channel_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._resize_output_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'device', 'simple_type': 'Device'}], + torch._C._VariableFunctions.empty_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'qtensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.empty_strided: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.full_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.from_file: [{'is_kwarg_only': 'False', 'name': 'filename', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.grid_sampler: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions._grid_sampler_2d_cpu_fallback: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.grid_sampler_3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}], + torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}], + torch._C._VariableFunctions.hinge_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_groups', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.native_group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'HxW', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'group', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}], + torch._C._VariableFunctions._validate_compressed_sparse_indices: [{'is_kwarg_only': 'False', 'name': 'is_crow', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'compressed_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cdim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'nnz', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_get_plan_cache_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_get_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._cufft_set_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}, {'is_kwarg_only': 'False', 'name': 'max_size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._cufft_clear_plan_cache: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}], + torch._C._VariableFunctions._unsafe_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}], + torch._C._VariableFunctions._unsafe_masked_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'fill', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._unsafe_masked_index_put_accumulate: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unsafe_index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._index_put_impl_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.instance_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'use_input_stats', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}], + torch._C._VariableFunctions.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kl_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.native_layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._fused_rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double?'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_linear_backward_weights: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias_defined', 'simple_type': 'bool'}], + torch._C._VariableFunctions._cslt_compress: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._cslt_sparse_mm_search: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_tile: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_apply_dense: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'meta', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_mm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_semi_structured_addmm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mixed_dtypes_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_int8_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_quantize_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_gemm_matrix_fp16: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_linear_prepack: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._wrapped_quantized_linear_prepacked: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_channel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'K', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.margin_ranking_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.max_pool1d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.quantized_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.quantized_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.quantized_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._mps_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.mkldnn_rnn_layer: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight0', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reverse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}], + torch._C._VariableFunctions.miopen_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_depthwise_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}], + torch._C._VariableFunctions.miopen_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.miopen_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_int4pack_mm_with_scales_and_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_weight_to_int4pack_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int4pack_mm_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dyn_quant_pack_4bit_weight: [{'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales_zeros', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._dyn_quant_matmul_4bit: [{'is_kwarg_only': 'False', 'name': 'inp', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._weight_int8pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sparse_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions._native_batch_norm_legit_no_training: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}], + torch._C._VariableFunctions.batch_norm_gather_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.batch_norm_gather_stats_with_counts: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'counts', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_backward_reduce: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'input_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'weight_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bias_g', 'simple_type': 'bool'}], + torch._C._VariableFunctions.batch_norm_backward_elemt: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'sum_dy', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sum_dy_xmu', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.batch_norm_update_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}], + torch._C._VariableFunctions.is_vulkan_available: [], + torch._C._VariableFunctions._nnpack_available: [], + torch._C._VariableFunctions._nnpack_spatial_convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.ones_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pairwise_distance: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cdist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._euclidean_dist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pdist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cosine_similarity: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pixel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'upscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.pixel_unshuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'downscale_factor', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.native_channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions._pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson_nll_loss: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'log_input', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'full', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'reduction', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scalar_tensor: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}], + torch._C._VariableFunctions.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._mkldnn_reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rrelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._prelu_kernel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.selu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.selu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.celu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.threshold_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._mkldnn_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._mkldnn_transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._VariableFunctions.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._transform_bias_rescale_qkv: [{'is_kwarg_only': 'False', 'name': 'qkv', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._nested_tensor_from_mask: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_from_mask_left_aligned: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cpu_nested_shape_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_and_nested_example: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nt_example', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_lengths: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_ragged_idx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_min_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_max_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_get_jagged_dummy: [{'is_kwarg_only': 'False', 'name': 'any', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_compute_contiguous_strides_offsets: [{'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._trilinear: [{'is_kwarg_only': 'False', 'name': 'i1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'expand1', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand2', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand3', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sumdim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.triplet_margin_loss: [{'is_kwarg_only': 'False', 'name': 'anchor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'positive', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'negative', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._has_compatible_shallow_copy_type: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unique_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unique_consecutive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._unique2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm_except_dim: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._weight_norm_interface: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._efficientzerotensor: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.zeros_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._standard_gamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._dirichlet_grad: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'total', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sample_dirichlet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.poisson: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.binomial: [{'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'prob', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._sparse_csr_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_csr_prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions._sparse_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._sparse_log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._scaled_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType?'}], + torch._C._VariableFunctions._scaled_grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_grouped_mm_v2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_a', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'recipe_b', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'swizzle_b', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._validate_sparse_coo_tensor_args: [{'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_compressed_tensor_args: [{'is_kwarg_only': 'False', 'name': 'compressed_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'layout', 'simple_type': 'Layout'}], + torch._C._VariableFunctions._validate_sparse_csr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_csc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._validate_sparse_bsc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._to_cpu: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions._to_sparse_semi_structured: [{'is_kwarg_only': 'False', 'name': 'dense', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantize_per_tensor_dynamic: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'reduce_range', 'simple_type': 'bool'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.quantize_per_channel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._make_per_tensor_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_per_channel_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_enabled', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fake_quantize_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fake_quantize_learnable_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.fused_moving_avg_obs_fake_quant: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fused_moving_avg_obs_fq_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._choose_qparams_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._saturate_weight_to_fp16: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.choose_qparams_optimized: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'numel', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'n_bins', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ratio', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'bit_width', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'indexing', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.cartesian_prod: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.combinations: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar1', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scalar2', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.can_cast: [{'is_kwarg_only': 'False', 'name': 'from_', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.promote_types: [{'is_kwarg_only': 'False', 'name': 'type1', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'type2', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions._lstm_mps: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}], + torch._C._VariableFunctions.lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantized_lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.quantized_rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._pack_padded_sequence: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}], + torch._C._VariableFunctions._pad_packed_sequence: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'total_length', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._masked_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.tril_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.triu_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_check_errors: [{'is_kwarg_only': 'False', 'name': 'info', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'api_name', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'True', 'name': 'is_matrix', 'simple_type': 'bool'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._lu_with_info: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._histogramdd_bin_edges: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_cts: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._histogramdd_from_bin_tensors: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._amp_foreach_non_finite_check_and_unscale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'inv_scale', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._amp_update_scale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'growth_tracker', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_growth_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'scale_backoff_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'growth_interval', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foreach_abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}], + torch._C._VariableFunctions._foreach_reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._foreach_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._VariableFunctions._adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._add_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._remove_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_serialization_subcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_parallel_materialize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_parallel', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'bool'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.segment_reduce: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch._C._VariableFunctions._nested_tensor_from_tensor_list: [{'is_kwarg_only': 'False', 'name': 'list', 'simple_type': 'TensorList'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_from_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._nested_tensor_softmax_with_shape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._safe_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._transformer_encoder_layer_fwd: [{'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'use_gelu', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'norm_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_2', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._native_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_sdp_choice: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_attention_math_for_mps: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_flash_attention_for_cpu: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._scaled_dot_product_efficient_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._scaled_dot_product_cudnn_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}], + torch._C._VariableFunctions._triton_scaled_dot_attention: [{'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fill_mem_eff_dropout_mask_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dropout_p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'seed', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'offset', 'simple_type': 'int64_t'}], + torch._C._VariableFunctions._triton_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._foobar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}], + torch._C._VariableFunctions._propagate_xla_data: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_linear: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch._C._nn.relu6: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.relu6_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.gelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.silu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.one_hot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv2d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.mkldnn_reorder_conv3d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.cross_entropy_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.elu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardsigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardtanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.hardswish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.leaky_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.rrelu_with_noise_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}], + torch._C._nn._pad_circular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn._pad_enum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}], + torch._C._nn.pad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.conv_depthwise3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.slow_conv_dilated2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}], + torch._C._nn.slow_conv_dilated3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}], + torch._C._nn._test_optional_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?'}], + torch._C._nn._test_optional_filled_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?', 'size': 2}], + torch._C._nn._test_optional_floatlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'ArrayRef?'}], + torch._C._nn._test_string_default: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}], + torch._C._nn._test_warn_in_autograd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._nn.pad_sequence: [{'is_kwarg_only': 'False', 'name': 'sequences', 'simple_type': 'TensorList'}], + torch._C._nn.flatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.unflatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'flat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._nn.scaled_dot_product_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_diagonal: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}], + torch._C._linalg.linalg_vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg._linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}], + torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch._C._special.special_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}], + torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch._C._fft.fft_fftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch._C._fft.fft_ifftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.retain_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rename_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.rename: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'order', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'ellipsis_idx', 'simple_type': 'int64_t'}], + torch.Tensor.align_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.refine_names: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.absolute_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sgn_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.chalf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.addr_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.baddbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Tensor'}], + torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logical_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}], + torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fill_diagonal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.new_empty: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_empty_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_full: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}], + torch.Tensor.new_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.new_ones: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.expand: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.expand_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}], + torch.Tensor.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.mvlgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}], + torch.Tensor.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}], + torch.Tensor.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}], + torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}], + torch.Tensor.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_pinned: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.repeat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}], + torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}], + torch.Tensor.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.reshape_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}], + torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}], + torch.Tensor.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}], + torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}], + torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}], + torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.hash_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sum_to_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.t_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}], + torch.Tensor.transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}], + torch.Tensor.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}], + torch.Tensor.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_strides: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nested_tensor_storage_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.type_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.unsqueeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.view_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}], + torch.Tensor.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}], + torch.Tensor.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.heaviside_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}], + torch.Tensor.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.addmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_resize_and_clear_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}], + torch.Tensor.sparse_mask: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor._sparse_mask_projection: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sparse_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimI: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dense_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._dimV: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._nnz: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.is_coalesced: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._coalesced_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coalesced', 'simple_type': 'bool'}], + torch.Tensor.indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.crow_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.col_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ccol_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.row_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}], + torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor._to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}], + torch.Tensor.to_mkldnn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qscheme: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor._autocast_to_reduced_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cuda_dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'cpu_dtype', 'simple_type': 'ScalarType'}], + torch.Tensor._autocast_to_full_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}], + torch.Tensor.is_set_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}], + torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}], + torch.Tensor.put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}], + torch.Tensor.index_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}], + torch.Tensor.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.scatter_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.tril_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.triu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t?'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t'}], + torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.uniform_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cauchy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.log_normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.exponential_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geometric_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}], + torch.Tensor.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}], + torch.Tensor.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}], + torch.Tensor.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}], + torch.Tensor.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}], + torch.Tensor.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapaxes_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.swapdims_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}], + torch.Tensor.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}], + torch.Tensor.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}], + torch.Tensor.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}], + torch.Tensor.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}], + torch.Tensor.lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.polygamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}], + torch.Tensor.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.erfinv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.arctan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}], + torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}], + torch.Tensor.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}], + torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.hypot_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.igammac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.nextafter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}], + torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}], + torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}], + torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}], + torch.Tensor.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}], + torch.Tensor.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}], + torch.Tensor.unfold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}], + torch.Tensor.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}], + torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}], + torch.Tensor.normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.record_stream: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Stream'}], + torch.Tensor.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}], + torch.Tensor.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}], + torch.Tensor.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}], + torch.Tensor.to_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'double'}], +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96317780dffb52409562c395c06797b3c658a1db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,440 @@ +# mypy: ignore-errors + +import contextlib +import functools +import logging +import os +import re +import sys +import unittest +from subprocess import CalledProcessError + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._inductor.codecache import CppCodeCache +from torch._inductor.codegen.common import ( + get_custom_backend_config_for_device, + get_custom_backend_pass_for_device, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + init_backend_registration, + register_backend_for_device, +) +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.compile_fx import shape_env_from_inputs +from torch._inductor.custom_graph_pass import CustomGraphModulePass +from torch._inductor.graph import GraphLowering +from torch._inductor.utils import ( + get_gpu_shared_memory, + get_gpu_type, + GPU_TYPES, + is_big_gpu, + is_gpu, + OrderedSet, +) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._helion import has_helion +from torch.utils._pallas import has_pallas_package, has_tpu_pallas +from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule +from torch.testing._internal.common_device_type import ( + get_desired_device_type_test_bases, +) +from torch.testing._internal.common_utils import ( + IS_CI, + IS_FBCODE, + IS_WINDOWS, + LazyVal, + TestCase, +) + +log: logging.Logger = logging.getLogger(__name__) + + +def test_cpu(): + try: + CppCodeCache.load("") + return not IS_FBCODE + except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, + ): + return False + + +HAS_CPU = LazyVal(test_cpu) + +HAS_TRITON = has_triton() + +HAS_PALLAS = has_pallas_package() + +HAS_HELION = has_helion() + +if HAS_TRITON: + import triton + + TRITON_HAS_CPU = "cpu" in triton.backends.backends +else: + TRITON_HAS_CPU = False + + +HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON + +HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON + +HAS_MPS = torch.mps.is_available() + +HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON +HAS_GPU_AND_TRITON = HAS_GPU + +GPU_TYPE = get_gpu_type() + +HAS_MULTIGPU = any( + getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2 + for gpu in GPU_TYPES +) + +_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True) +RUN_GPU = HAS_GPU and any( + is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases +) + +RUN_CPU = HAS_CPU and any( + getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases +) + +HAS_TPU = has_tpu_pallas() +RUN_TPU = HAS_TPU + + +def _check_has_dynamic_shape( + self: TestCase, + code, +): + for_loop_found = False + has_dynamic = False + lines = code.split("\n") + for line in lines: + if "for(" in line: + for_loop_found = True + if re.search(r";.*ks.*;", line) is not None: + has_dynamic = True + break + self.assertTrue( + has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}" + ) + self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}") + + +def skipDeviceIf(cond, msg, *, device): + if cond: + + def decorate_fn(fn): + @functools.wraps(fn) + def inner(self, *args, **kwargs): + if not hasattr(self, "device"): + warn_msg = ( + "Expect the test class to have attribute device but not found. " + ) + if hasattr(self, "device_type"): + warn_msg += "Consider using the skip device decorators in common_device_type.py" + log.warning(warn_msg) + if self.device == device: + raise unittest.SkipTest(msg) + return fn(self, *args, **kwargs) + + return inner + + else: + + def decorate_fn(fn): + return fn + + return decorate_fn + + +def skip_windows_ci(name: str, file: str) -> None: + if IS_WINDOWS and IS_CI: + module = os.path.basename(file).strip(".py") + sys.stderr.write( + f"Windows CI does not have necessary dependencies for {module} tests yet\n" + ) + if name == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + + +# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS +requires_gpu = functools.partial( + unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu" +) +requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") +requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion") + + +def requires_cuda_with_enough_memory(min_mem_required): + def inner(fn): + if ( + not torch.cuda.is_available() + or torch.cuda.get_device_properties().total_memory < min_mem_required + ): + return unittest.skip( + f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe" + )(fn) + else: + return fn + + return inner + + +skipCUDAIf = functools.partial(skipDeviceIf, device="cuda") +skipXPUIf = functools.partial(skipDeviceIf, device="xpu") +skipCPUIf = functools.partial(skipDeviceIf, device="cpu") + +IS_A100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 166912) + +IS_H100 = LazyVal(lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448) + +IS_BIG_GPU = LazyVal(lambda: HAS_GPU_AND_TRITON and is_big_gpu()) + + +def dummy_graph() -> GraphLowering: + """ + Create a graph. This is useful for unit testing code which accesses + V.graph.sizevars. + """ + example_inputs = [torch.randn(10) for _ in range(2)] + gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs) + shape_env = shape_env_from_inputs(example_inputs) + graph = GraphLowering( + gm, + shape_env=shape_env, + ) + + return graph + + +def maybe_skip_size_asserts(op): + """ + For certain ops, there meta and eager implementation returns different + strides. This cause size/strides assert fail. Skip adding those + asserts for now. + """ + if ( + op.aten_name + in ( + "fft_hfftn", + "fft_hfft", + "fft_hfft2", + "fft_ihfftn", + "fft_fft", + "fft_fft2", + "fft_fftn", + "fft_ifft", + "fft_ifft2", + "fft_ifftn", + "fft_irfft", + "fft_irfft2", + "fft_irfftn", + "fft_ihfft", + "fft_ihfft2", + "fft_rfft", + "fft_rfft2", + "fft_rfftn", + "linalg_eig", + "linalg_eigvals", + ) + and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ + ): + return torch._inductor.config.patch(size_asserts=False) + else: + return contextlib.nullcontext() + + +def get_func_call() -> str: + return ( + "void inductor_entry_impl(" + if torch._inductor.config.cpp_wrapper + else "def call(" + ) + + +def get_kernel_launch() -> str: + return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run(" + + +def clone_preserve_strides_offset(x, device=None): + if not isinstance(x, torch.Tensor): + return x + buffer = torch.as_strided( + x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 + ) + if not device: + buffer = buffer.clone() + else: + buffer = buffer.to(device, copy=True) + out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + return out + + +# define the e4m3/e5m2 constants +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max +E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max +E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max +E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max + +FP16_MAX_POS: float = torch.finfo(torch.float16).max +EPS: float = 1e-12 + +Tensor = torch.Tensor + + +def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor: + # The default behavior in PyTorch for casting to `float8_e4m3fn` + # and `e5m2` is to not saturate. In this context, we should saturate. + # A common case where we want to saturate is when the history of a + # tensor has a maximum value of `amax1`, and the current amax value + # is `amax2`, where `amax1 < amax2`. This is common when using delayed + # scaling. + if float8_dtype == torch.float8_e4m3fn: + x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + elif float8_dtype == torch.float8_e5m2: + x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + elif float8_dtype == torch.float8_e4m3fnuz: + x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS) + elif float8_dtype == torch.float8_e5m2fnuz: + x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS) + else: + raise TypeError(f"Unsupported float8_dtype: {float8_dtype}") + return x.to(float8_dtype) + + +@torch.no_grad() +def _amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype +) -> torch.Tensor: + # To make scale dtype to be fp32 for accuracy + amax = amax.float() + if float8_dtype == torch.float8_e4m3fn: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: # e5m2 + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + + # Ensure that the scale is representable in float16, + # this helps when amax is small. We are assuming that we don't need + # to care about this for float32/bfloat16. + if orig_dtype is torch.float16: + res = torch.clamp(res, max=FP16_MAX_POS) + return res + + +def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x)) + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): + amax = torch.max(torch.abs(x), dim=1, keepdim=True).values + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x_fp8 = _to_fp8_saturated(x * scale, float8_dtype) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +def _quantize_blockwise( + x: Tensor, float8_dtype: torch.dtype, block_outer: int, block_inner: int +): + min_outer = min(block_outer, x.shape[0]) + min_inner = min(block_inner, x.shape[1]) + x = x.unflatten(1, (-1, min_inner)).unflatten(0, (-1, min_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = _amax_to_scale(amax, float8_dtype, x.dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + scale_expanded = scale.repeat_interleave(min_outer, dim=0).repeat_interleave( + min_inner, dim=1 + ) + x_fp8 = _to_fp8_saturated( + x / scale_expanded, # Ensures that scaling doesn't cause inf/nan values + float8_dtype, + ) + inverse_scale = scale.reciprocal() + return x_fp8, inverse_scale + + +class MockGraphHandler(GraphLowering): + """Minimal mock graph handler for testing virtualized context.""" + + def __init__(self, name_to_buffer=None): + import torch._inductor.sizevars + + self.sizevars = torch._inductor.sizevars.SizeVarAllocator() + self.name_to_buffer = name_to_buffer or {} + self.graph_inputs = {} + self.mutated_buffers = OrderedSet() + self.removed_buffers = OrderedSet() + self.constants = {} + self.scheduler = None + + def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002 + """Return default dtype for any buffer (for testing).""" + return torch.float32 + + +@contextlib.contextmanager +def patch_inductor_backend( + device: str, + python_wrapper_codegen: PythonWrapperCodegen = None, + custom_pass: CustomGraphModulePass = None, + custom_backend_config: ConfigModule = None, +): + """ + Patch the inductor backend for a specific device. + """ + # Make sure the backend is already registered + init_backend_registration() + + # Get the original registration parameters + original_scheduling = get_scheduling_for_device(device) + original_python_wrapper = get_wrapper_codegen_for_device(device, False) + original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) + original_fx_wrapper = get_wrapper_codegen_for_device(device, fx_wrapper=True) + original_custom_pass = get_custom_backend_pass_for_device(device) + original_custom_backend_config = get_custom_backend_config_for_device(device) + + try: + # Register modified backend for the device + register_backend_for_device( + device, + original_scheduling, + ( + python_wrapper_codegen + if python_wrapper_codegen is not None + else original_python_wrapper + ), + original_cpp_wrapper, + original_fx_wrapper, + custom_pass if custom_pass is not None else original_custom_pass, + ( + custom_backend_config + if custom_backend_config is not None + else original_custom_backend_config + ), + ) + yield + finally: + # Restore the original backend + register_backend_for_device( + device, + original_scheduling, + original_python_wrapper, + original_cpp_wrapper, + original_fx_wrapper, + original_custom_pass, + original_custom_backend_config, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..e71f0f46854756a4b4251df6a53a03a288183172 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py @@ -0,0 +1,168 @@ +# mypy: ignore-errors + +import torch +from torch.utils._pytree import tree_map +from typing import Optional +from collections.abc import Iterator +import logging +import contextlib +import itertools +from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.weak import WeakTensorKeyDictionary +import functools +from torch._C._profiler import gather_traceback, symbolize_tracebacks + +logger = logging.getLogger("LoggingTensor") + +# How the chain of calls works for LoggingTensor: +# 1. Call torch.sin +# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely +# 3. Enter dispatcher, wind your way through Autograd +# 4. Hit Python dispatch key, call __torch_dispatch__ + +# This Tensor can work with autograd in two ways: +# - The wrapped Tensor does not require gradients. In that case, the LoggingTensor +# can require gradients if the user asks for it as a constructor kwarg. +# - The wrapped Tensor can require gradients. In that case autograd will be tracked +# for the wrapped Tensor and the LoggingTensor itself cannot require gradients. +# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single +# test or you might get surprising behavior. + +# TODO: TensorBase should work +class LoggingTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + context = contextlib.nullcontext + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # The wrapping tensor (LoggingTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = torch.Tensor._make_wrapper_subclass( + cls, elem.size(), + strides=elem.stride(), storage_offset=elem.storage_offset(), + # TODO: clone storage aliasing + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=kwargs.get("requires_grad", False) + ) + # ...the real tensor is held as an element on the tensor. + r.elem = elem.detach() if r.requires_grad else elem + return r + + def __repr__(self): + return super().__repr__(tensor_contents=f"{self.elem}") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, cls) else e + + def wrap(e): + return cls(e) if isinstance(e, torch.Tensor) else e + + with cls.context(): + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + rs = func(*args, **kwargs) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) # noqa: G004 + return rs + +class LoggingTensorReentrant(LoggingTensor): + context = torch.overrides.enable_reentrant_dispatch + +# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list +class LoggingTensorHandler(logging.Handler): + def __init__( + self, log_list: list[str], use_shortid_for_all_tensors: bool, + with_type: bool, tracebacks_list: Optional[list]) -> None: + logging.Handler.__init__(self) + self.log_list = log_list + self.use_shortid_for_all_tensors = use_shortid_for_all_tensors + self.tracebacks_list = tracebacks_list + self.memo = WeakTensorKeyDictionary() + self.next_id = 0 + self.with_type = with_type + + def _shortid(self, t: torch.Tensor) -> int: + if t not in self.memo: + self.memo[t] = self.next_id + self.next_id += 1 + return self.memo[t] + + def _fmt(self, a: object, with_type: bool = False) -> str: + cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor + if isinstance(a, cond_cls): + maybe_type = "" + if with_type and self.with_type: + maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]" + x = f"${self._shortid(a)}{maybe_type}" + return x + else: + return repr(a) + + def emit(self, record): + fmt_args = ", ".join( + itertools.chain( + (str(tree_map(self._fmt, a)) for a in record.args[0]), + (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()), + ) + ) + fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2]) + self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})') + if self.tracebacks_list is not None: + self.tracebacks_list.append(record.traceback) + +def log_input(name: str, var: object) -> None: + logger.info("input", (name,), {}, var) # noqa: PLE1205 + +class GatherTraceback(logging.Filter): + def __init__(self, python=True, script=True, cpp=False): + self.python = python + self.script = script + self.cpp = cpp + + def filter(self, record): + record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp) + return True + +@contextlib.contextmanager +def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]: + collect_traceback = python_tb or script_tb or cpp_tb + log_list: list[str] = [] + tracebacks_list: list[str] = [] + handler = LoggingTensorHandler( + log_list, + with_type=True, + use_shortid_for_all_tensors=is_mode, + tracebacks_list=tracebacks_list if collect_traceback else None + ) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + if collect_traceback: + logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)) + try: + if collect_traceback: + yield log_list, tracebacks_list + else: + yield log_list + finally: + symbolized_tracebacks = symbolize_tracebacks(tracebacks_list) + tracebacks_list.clear() + tracebacks_list.extend(symbolized_tracebacks) + logger.removeHandler(handler) + +@contextlib.contextmanager +def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False): + with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs: + yield logs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97c38f3560625213fbd59d09a9cfd22bad26ba04 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py @@ -0,0 +1,4 @@ +# mypy: ignore-errors + +import torch.testing._internal.opinfo.core +import torch.testing._internal.opinfo.definitions diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e88e239e7b6ce0567c09e8640c23a9547dd67e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py @@ -0,0 +1,3221 @@ +# mypy: ignore-errors + +import collections +import collections.abc +import contextlib +import logging +import math +import operator +import unittest +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable +from dataclasses import asdict, dataclass, field +from enum import Enum +from functools import partial +from itertools import product +from typing import Any, Optional, TypeVar, Union + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import ( + skipCPUIfNoFFT, + tol, + toleranceOverride, +) +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + get_all_dtypes, +) +from torch.testing._internal.common_utils import ( + extract_test_fn, + IS_FBCODE, + is_iterable_of_tensors, + noncontiguous_like, + OPINFO_SAMPLE_INPUT_INDEX, + TEST_WITH_ROCM, + torch_to_numpy_dtype_dict, + TrackedInputIter, + USE_PYTEST, +) +from torch.testing._internal.opinfo import utils +from torchgen.utils import dataclass_repr + + +# setup logging +log = logging.getLogger(__name__) + +# Reasonable testing sizes for dimensions +L = 20 +M = 10 +S = 5 +XS = 3 + +# Unique value to distinguish default from anything else +_NOTHING = object() + + +# Extension of getattr to support qualified names +# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm +def _getattr_qual(obj, name, default=_NOTHING): + try: + for path in name.split("."): + obj = getattr(obj, path) + return obj + except AttributeError: + if default is not _NOTHING: + return default + else: + raise + + +class DecorateInfo: + """Describes which test, or type of tests, should be wrapped in the given + decorators when testing an operator. Any test that matches all provided + arguments will be decorated. The decorators will only be applied if the + active_if argument is True.""" + + __slots__ = [ + "decorators", + "cls_name", + "test_name", + "device_type", + "dtypes", + "active_if", + ] + + def __init__( + self, + decorators, + cls_name=None, + test_name=None, + *, + device_type=None, + dtypes=None, + active_if=True, + ): + self.decorators = ( + list(decorators) + if isinstance(decorators, collections.abc.Sequence) + else [decorators] + ) + self.cls_name = cls_name + self.test_name = test_name + self.device_type = device_type + self.dtypes = dtypes + self.active_if = active_if + + # Validate dtypes + if self.dtypes is not None: + for dtype in self.dtypes: + assert isinstance(dtype, torch.dtype) + + def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs): + return ( + self.active_if + and (self.cls_name is None or self.cls_name == cls_name) + and (self.test_name is None or self.test_name == test_name) + and (self.device_type is None or self.device_type == device_type) + and (self.dtypes is None or dtype in self.dtypes) + # Support callables over kwargs to determine if the decorator is active. + and ( + self.active_if(param_kwargs) + if isinstance(self.active_if, Callable) + else self.active_if + ) + ) + + +# FIXME +# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying +# to support scalar inputs, too. Some tests still depend on 'input' being a Tensor +# or TensorList, however. +class SampleInput: + """Represents sample inputs to a function.""" + + __slots__ = [ + "input", + "args", + "kwargs", + "output_process_fn_grad", + "broadcasts_input", + "name", + ] + + def __init__( + self, + input, + *var_args, + args=None, + kwargs=None, + output_process_fn_grad=None, + broadcasts_input=None, + name=None, + **var_kwargs, + ): + # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]). + # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). + self.input = input + + # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as + # SampleInput(input, *args, **kwargs) but not to mix the two forms + if args is not None or kwargs is not None: + assert not var_args and not var_kwargs, """ +A SampleInput can be constructed "naturally" with *args and **kwargs or by +explicitly setting the "args" and "kwargs" parameters, but the two +methods of construction cannot be mixed!""" + elif var_args or var_kwargs: + assert ( + output_process_fn_grad is None + and broadcasts_input is None + and name is None + ), """ +A SampleInput constructed "naturally" with *args and **kwargs +cannot specify additional metadata in keyword arguments""" + + self.args = args if args is not None else var_args + assert isinstance(self.args, tuple) + self.kwargs = kwargs if kwargs is not None else var_kwargs + assert isinstance(self.kwargs, dict) + + self.output_process_fn_grad = ( + output_process_fn_grad + if output_process_fn_grad is not None + else lambda x: x + ) + self.name = name if name is not None else "" + + # Specifies if `self.input` is broadcasted or not, + # given that the operator supports broadcasting. + # This field is used to verify the behavior for inplace variant. + # + # If a SampleInput is marked with `broadcasts_input=True`, + # it is verified that we get a `RuntimeError` with this sample, + # and inplace variant. Also inplace grad{grad} tests are skipped, + # for such inputs (as they will error out otherwise). + self.broadcasts_input = ( + broadcasts_input if broadcasts_input is not None else False + ) + + def with_metadata( + self, *, output_process_fn_grad=None, broadcasts_input=None, name=None + ): + if output_process_fn_grad is not None: + self.output_process_fn_grad = output_process_fn_grad + if broadcasts_input is not None: + self.broadcasts_input = broadcasts_input + if name is not None: + self.name = name + return self + + def _repr_helper(self, formatter): + # Helper function to return the details of the SampleInput as `str` + # It consolidates all the fields of SampleInput and allows, + # formatting the fields like `input`, `args`, etc with `formatter` + # callable to customize the representation. + # Look at `summary` method for example. + arguments = [ + f"input={formatter(self.input)}", + f"args={formatter(self.args)}", + f"kwargs={formatter(self.kwargs)}", + f"broadcasts_input={self.broadcasts_input}", + f"name={repr(self.name)}", + ] + + return f"SampleInput({', '.join(a for a in arguments if a is not None)})" + + def __repr__(self): + return self._repr_helper(lambda x: x) + + def summary(self): + # Returns the SampleInput details in a more + # friendly format. + # It formats `Tensor` and `TensorList` + # in a more condensed representation. + def formatter(arg): + # Format any instance of `Tensor` (standalone, in list, or in dict) + # by Tensor[TensorShape] + # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] + if isinstance(arg, torch.Tensor): + shape = str(tuple(arg.shape)) + dtype = str(arg.dtype) + device = str(arg.device) + contiguity_suffix = "" + # NB: sparse CSR tensors annoyingly return is_sparse=False + is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr + if not is_sparse and not arg.is_contiguous(): + contiguity_suffix = ", contiguous=False" + return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' + elif isinstance(arg, dict): + return {k: formatter(v) for k, v in arg.items()} + elif is_iterable_of_tensors(arg): + return "TensorList[" + ", ".join(map(formatter, arg)) + "]" + elif isinstance(arg, (list, tuple)): # Handle list, tuple + return "(" + ",".join(map(formatter, arg)) + ")" + + return repr(arg) + + return self._repr_helper(formatter) + + # Applies the transform f(t) -> t to each tensor and dtype in the SampleInput + def transform(self, f): + def tt(t): + def _tt(t): + with torch.no_grad(): + return f(t) + + if isinstance(t, torch.Tensor): + return _tt(t) + elif isinstance(t, torch.dtype): + return _tt(t) + elif isinstance(t, list): + return list(map(tt, t)) + elif isinstance(t, tuple): + return tuple(map(tt, t)) + elif isinstance(t, dict): + return {k: tt(v) for k, v in t.items()} + else: + return t + + sample_tt_input, tt_args, tt_kwargs = ( + tt(self.input), + tt(self.args), + tt(self.kwargs), + ) + + # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid! + return SampleInput( + sample_tt_input, + args=tt_args, + kwargs=tt_kwargs, + output_process_fn_grad=self.output_process_fn_grad, + broadcasts_input=self.broadcasts_input, + name=self.name + "_transformed", + ) + + # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs) + # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them + # Converts dtypes by remapping them using torch_to_numpy_dtype_dict + def numpy(self): + def to_numpy(t): + if isinstance(t, torch.Tensor): + if t.dtype is torch.bfloat16: + return t.detach().cpu().to(torch.float32).numpy() + if t.dtype is torch.chalf: + return t.detach().cpu().to(torch.cfloat).numpy() + return t.detach().cpu().numpy() + elif isinstance(t, torch.dtype): + return torch_to_numpy_dtype_dict[t] + + return t + + return self.transform(to_numpy) + + def noncontiguous(self): + def to_noncontiguous(t): + if isinstance(t, torch.Tensor): + return noncontiguous_like(t) + elif isinstance(t, torch.dtype): + return t + + return t + + return self.transform(to_noncontiguous) + + +NumericsFilter = collections.namedtuple("NumericsFilter", ["condition", "safe_val"]) + + +class ErrorInput: + """ + A SampleInput that will cause the operation to throw an error plus information + about the resulting error. + """ + + __slots__ = ["sample_input", "error_type", "error_regex"] + + def __init__(self, sample_input, *, error_type=RuntimeError, error_regex): + self.sample_input = sample_input + self.error_type = error_type + self.error_regex = error_regex + + +class AliasInfo: + """Class holds alias information. For example, torch.abs -> + torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ + """ + + def __init__(self, alias_name): + self.name = alias_name + self.op = _getattr_qual(torch, alias_name) + self.method_variant = getattr(torch.Tensor, alias_name, None) + self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None) + + def __call__(self, *args, **kwargs): + return self.op(*args, **kwargs) + + +# Note [OpInfos] +# ~~~~~~~~~~~~~~ +# +# The majority of this note was written shortly after the PyTorch 1.9 release. +# If you notice it's out-of-date or think it could be improved then please +# file an issue. +# +# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261) +# See also: "Writing Test Templates" in common_device_type.py to learn how to +# parametrize a test template using OpInfos. +# See also: PyTorch's GitHub wiki on running and writing tests +# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py +# +# An OpInfo is a collection of metadata related to a PyTorch operator. This +# metadata is used to generate tests that validate properties of the operator, +# like if it implements the correct gradient formula. +# +# WHY OPINFOS? +# ~~~~~~~~~~~~ +# +# OpInfos are principally intended to do three things: +# +# 1) to allow systematic testing over all PyTorch's operators +# 2) to simplify operating testing by autogenerating many tests +# 3) to allow systems (like autograd, torchscript, fx, nnc...) to test +# against every PyTorch operator +# +# All these goals are still a work in progress. Not every operator has an +# OpInfo, and some operator tests that could be automatically generated +# still have to be written manually. +# +# It's helpful to understand that OpInfos are both about test simplification and +# modularity. PyTorch is a complicated framework with many interrelated systems, +# too many for any one person to keep track of. An OpInfo can be thought of as the +# interface between an operator implementer and those other systems. Instead of +# requiring the implementer of torch.foo understand how to test its forward +# mode AD or NNC support that's typically handled automatically just by +# defining an OpInfo. +# +# It's often surprising to OpInfo writers that just implementing an OpInfo +# typically can't verify an operator is actually implemented correctly: +# +# "If an OpInfo doesn't validate my op works as expected, what's the point +# of it?" +# +# But the point of is the above. OpInfos are intended to let you focus on testing +# the operator logic you're familiar with instead of having to write tests for +# how the operator interacts with each of PyTorch's many systems. +# +# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES +# validate your op works as expected, but that's only in special +# cases. See below for details. +# +# WHAT'S AN OPINFO? +# ~~~~~~~~~~~~~~~~~ +# +# So what is an OpInfo? It's a Python class that describes an operator's properties, +# like which dtypes it supports on the CPU and whether it has any aliases. +# These properties can be divided into three categories: +# +# 1) Metadata describing the operator, like the operator's name and if it +# "supports" the out kwarg. +# 2) Test directives, like "skips" that tell the test suite to skip some +# tests. +# 3) A "sample inputs" function that generates valid inputs for the operator. +# +# OpInfo attributes are described in more detail below. +# +# THE SAMPLE INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The "sample inputs" function merits special elaboration. This function is +# crucial to testing with OpInfos. A typical OpInfo test has to treat the operator +# as a black box. There's no structure for the test to understand or exploit. +# Without "sample inputs" it wouldn't even know how to call the OpInfo's +# operator. The sample input function saves the day by providing different +# "SampleInputs" that can be used to call the operator. A sample input +# function should have the following signature: +# +# def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs): +# +# And should return an iterable of SampleInputs (see the class description +# above). Each SampleInput defines an "input", "args", "kwargs", an +# "output_process_fn_grad" function, the "broadcasts_input" bool and a +# "name". +# +# All the "sample_inputs" functions are invoked within a `torch.no_grad()` +# environment for efficiency and correctness. As such remember to set the +# "requires_grad" flag on the inputs **after** performing any transformations +# on them. +# +# The "input" is the first argument to the operator, or the tensor that +# the method or inplace variants of the operator should be called on, and +# should be on the requested device, of the requested dtype, and its +# requires_grad attribute should be set to the requires_grad argument. +# +# "args" should contain positional arguments, and "kwargs" keyword arguments. +# +# "output_process_fn_grad" has an interesting name. It's a function that maps +# the operator's output (when given the input, args, and kwargs) to the +# portion of the output to gradcheck. For example, consider an operator +# like torch.linalg.slogdet +# (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html). +# This operator returns a tuple of two tensors, but the first tensor +# cannot be backwarded through. Its "output_process_fn_grad" filters +# this output tuple to just the second argument, which we can call backward +# on. Functions that produce a single tensor can ignore this argument. +# +# "broadcasts_input" is a bool indicated if the SampleInput causes the operator +# to broadcast the "input" argument. This is important for tests to understand +# because inplace variants of operations throw a runtime error if they +# would broadcast their input arguments, so tests that work with inplace +# variants filter SampleInputs that broadcast their input. +# +# "name" is a string that's just used for debugging. It appears when printing +# the SampleInput. +# +# Sample inputs are designed to be used with many tests, some +# that are very time consuming, so they should be a small +# set with small tensors. An elaborated set of sample inputs +# can be specified using the "reference_inputs_func" attribute. +# The "reference inputs" for an operation are an extended +# set of sample inputs that can more exhaustively test an +# operator. They are used by only a few tests that are careful +# not to take too long to run. Adding reference inputs +# is highly encouraged! +# +# THE (OPTIONAL) ERROR INPUTS FUNCTION +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# OpInfos may optionally specify "error inputs" through an error function. If +# specified test_errors in test_ops.py will call the op with these inputs +# and validate that the desired error is thrown. +# +# Error inputs automate a common testing pattern where multiple inputs are +# passed to an operation and the errors they thrown are reviewed. Tests +# written in this style should be ported to the new OpInfo pattern. +# +# Error inputs are specified using the ErrorInputs class, which contains +# a SampleInput (see above) and data about the expected error. +# +# OPINFO FILE ORGANIZATION +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# All OpInfos are currently defined in this file. Most OpInfo tests are defined +# in test_ops.py, but some system-specific tests are defined in those +# systems' test files, and subclass-specific tests are defined in the test +# file that corresponds to that subclass (see the below). +# Expect a reorganization in the future. +# +# WHAT'S TESTED? +# ~~~~~~~~~~~~~~ +# +# Every OpInfo in the op_db sequence has the following properties validated in +# test_ops.py: +# +# - that its supported dtypes are specified correctly +# - that the operation produces the same results when called with noncontiguous inputs +# - that it supports the out= argument properly (if it allows out=), +# see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch +# - that it works with the conjugate view bit properly +# - that its function, method, and inplace variants perform the same operation +# (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all +# do the same thing). +# - that its inplace variant preserves the input's storage +# - that its gradient formula is implemented correctly, and that it supports +# gradgrad and complex grad and gradgrad and forward mode AD properly for +# the op's function and inplace variants (method variants are skipped +# to reduce test time). +# - that the operation performs the same operation when traced or scripted +# using the jit +# - that the operation is autodifferentiated by the jit as expected +# - that the operator's aliases, if any, perform the same operation and that +# the jit understands the alias +# - that the operator throws the correct errors (if error_inputs is defined) +# - that the operator produces the same results as a NumPy reference (if ref is defined) +# - that the operator produces the same results as a NumPy reference on an extended +# set of "reference inputs" (if both ref and reference_inputs_func are defined) +# (NOTE: elementwise unary and elementwise binary OpInfos do this even if only +# ref is defined, because they effectively autogenerate reference inputs) +# - that the operator works on different CUDA devices +# +# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py, +# and test_fx.py. These tests validate that operators work with NNC and FX +# as expected. +# +# For performance, some of the above tests may only run on the first +# SampleInput returned by an OpInfo's sample input function. +# +# In addition to these tests, some subclasses (discussed in the next section) +# define additional tests. +# +# Critically, as mentioned above, what's not necessarily tested is that the operator +# works as expected. When implementing an OpInfo an engineer must still +# typically write one or more tests validating the operator's behavior. +# The exception to this is if reference testing is sufficient, or if +# the operation belongs to an OpInfo subclass that has more exhaustive +# operator testing. Elementwise unary and elementwise binary operators, +# in particular, usually don't require additional testing beyond +# writing an Opinfo. +# +# +# OPINFO (SUB)CLASSES +# ~~~~~~~~~~~~~~~~~~~ +# +# In addition to the OpInfo base class there are several specialized OpInfo +# subclasses. For example, the UnaryUfuncInfo subclass is used for +# unary elementwise operations. These operations have a common structure +# that test_unary_ufuncs.py exploits with additional automated testing. +# The automated testing in test_unary_ufuncs.py is so thorough, comparing +# the operator to a NumPy reference function on a plethora of values, that +# just implementing an OpInfo for a unary elementwise operation is often +# sufficient testing. +# +# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a +# very unique class of operations. These OpInfos aren't included in the +# op_db sequence and have their own tests. +# +# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience +# when writing OpInfos. +# +# TESTING A NEW OPERATOR +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# If you're adding a new operator to any of the following namespaces: +# - torch +# - torch.fft +# - torch.linalg, +# - torch.special +# - torch.nn.functional +# then you should typically add an OpInfo for it. +# +# As mentioned a couple times above, implementing an OpInfo is not +# usually sufficient testing (unless the operator is a unary or binary elementwise +# operator). The OpInfo will only test the properties described in the +# "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is +# implemented correctly. +# +# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to +# be consumed by a variety of systems it can be hard to understand how to +# deal with test failures or how to set the OpInfo metadata properly. +# +# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs +# function must be defined, and the operator's dtypes must be specified. +# Once that's done you should run the operator's tests in test_ops.py +# (these can be filtered using the "-k" argument in pytest). Tests that +# fail should provide an error message that describes what to change about +# your OpInfo. You don't need to worry about changing an OpInfo's default +# values unless a test yells at you. +# +# Similarly, if you're writing a test that consumes OpInfos then it's critical +# your test provides a clear error message describing what to do when it +# fails. You should not assume the OpInfo implementer is familiar with your +# system. +# +# If you see a confusing error message while developing an OpInfo then please +# file an issue describing what happened. +# +# This trial-and-error approach to writing an OpInfo can be frustrating, +# but it's probably necessary as long as OpInfos don't require +# learning about all the systems that consume them. One thing that can help +# is the get_supported_dtypes() function defined in utils.py. This +# function can be used to programmatically specify the dtypes an operator +# supports, and is especially useful if writing an OpInfo on a machine +# without a CUDA device. See its documentation for more details. +# +# THE FUTURE OF OPINFOS AND OPINFO TESTING +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the future we expect OpInfo coverage to improve and cover +# the great majority of PyTorch's (public) operators. +# + + +# Classes and methods for the operator database +@dataclass +class OpInfo: + """Operator information and helper functions for acquiring it.""" + + # the string name of the function + name: str + + # An optional reference function that accepts ndarrays (AKA "NumPy arrays"). + # If given, the op will be compared with its reference on each of its sample inputs. + ref: Optional[Callable] = None + + # the following metadata describes the operator, its variants, and its aliases, if any + + # iterable of aliases, e.g. ("absolute",) for torch.abs + aliases: Iterable = None + + # additional string to include in the test name + # this is useful when an op needs multiple OpInfos, + # like divide does, often because it's really several + # different ops behind the scenes + variant_test_name: str = "" + + # the function variant of the operation, populated as torch. if None + op: Callable = None + + # allows the method variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated method + # - if a Callable, then that callable should be the method associated with this operation + method_variant: Callable = _NOTHING + + # allows the inplace variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace variant + # - if a Callable, then that callable should be the inplace variant associated with this operation + inplace_variant: Callable = _NOTHING + + # allows the operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated operator + # - if a Callable, then that callable should be the operator associated with this operation + operator_variant: Callable = _NOTHING + + # allows the inplace operator variant of this operation to be specified as follows: + # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name + # - if None, then the OpInfo explicitly specifies is has no associated inplace operator + # - if a Callable, then that callable should be the inplace operator associated with this operation + inplace_operator_variant: Callable = _NOTHING + + # the following metadata are test directives for skipping or modifying tests + + # information about which tests to skip + skips: tuple = () + + # decorators to apply to generated tests + decorators: tuple = () + + # the following are pointers to functions to generate certain classes of inputs + + # function to generate sample inputs with strided layouts + sample_inputs_func: Callable = None + + # function to generate a more thorough set of samples inputs with strided layouts + reference_inputs_func: Callable = None + + # function to generate inputs that will throw errors + error_inputs_func: Callable = None + + # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors + error_inputs_sparse_func: Callable = None + + # function to generate sample inputs with sparse coo layouts + sample_inputs_sparse_coo_func: Callable = None + + # function to generate sample inputs with sparse csr layouts + sample_inputs_sparse_csr_func: Callable = None + + # function to generate sample inputs with sparse csc layouts + sample_inputs_sparse_csc_func: Callable = None + + # function to generate sample inputs with sparse bsr layouts + sample_inputs_sparse_bsr_func: Callable = None + + # function to generate sample inputs with sparse bsc layouts + sample_inputs_sparse_bsc_func: Callable = None + + # the following metadata relates to dtype support and is tested for correctness in test_ops.py + + # dtypes this function works with on the CPU, + # inherited by other device types that don't specify their own dtypes + dtypes: _dispatch_dtypes = None + + # the following dtypesIf... options override the dtypes value on their respective device types + # I.e. instead of writing multiple `dtypesIfCUDA`, `dtypesIfROCM`, etc one can simply define a dict + # dtypesIf = { 'cuda': (torch.float, torch.double), 'rocm': (torch.half, torch.bfloat16) } + dtypesIf: dict[str, _dispatch_dtypes] = field(default_factory=dict) + + def __getattribute__(self, name: str) -> Any: + if name.startswith("dtypesIf") and name != "dtypesIf": + # TODO: Warn if used + dev_name = name.removeprefix("dtypesIf").lower() + return self.dtypesIf.get(dev_name) + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + # TODO: After migration, start adding warnings here + if name.startswith("dtypesIf") and name != "dtypesIf": + assert isinstance(value, (_dispatch_dtypes, type(None))) + dev_name = name.removeprefix("dtypesIf").lower() + self.dtypesIf[dev_name] = value + return + super().__setattr__(name, value) + + # dtypes this function is expected to work with on CUDA + dtypesIfCUDA: _dispatch_dtypes = None + + # dtypes this function is expected to work with on ROCM + dtypesIfROCM: _dispatch_dtypes = None + + dtypesIfHpu: _dispatch_dtypes = None + + # dtypes this function is expected to work with on XPU + dtypesIfXPU: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with + backward_dtypes: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on CUDA + backward_dtypesIfCUDA: _dispatch_dtypes = None + + # backward dtypes this function is expected to work with on ROCM + backward_dtypesIfROCM: _dispatch_dtypes = None + + backward_dtypesIfHpu: _dispatch_dtypes = None + + # the following metadata describes the operators out= support + + # whether the op supports the out kwarg + # defaults to True, if the op does not allow the out kwarg or + # supports it incorrectly then test_out in test_ops.py should fail + supports_out: bool = True + + # the following metadata relates to autograd support + # whether the operation supports backward mode AD + # if true, gradient correctness is tested in test_ops.py + # using the op's sample inputs + supports_autograd: bool = True + + # whether the op supports second order gradients + # if true, gradgrad correctness is tested in test_ops.py + # defaults to support_autograd's value + # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below + supports_gradgrad: bool = None + + # whether the ops supports second order gradients via + # forward-over-reverse. If True, forward-over-reverse gradgrad correctness + # is tested. If False, test that forward grad is not implemented. + # Defaults to False. + supports_fwgrad_bwgrad: bool = False + + # whether the operation supports inplace autograd + # if true, tested in test_ops.py + # defaults to supports_autograd's value + supports_inplace_autograd: bool = None + + # Whether the operation support forward mode AD + # If the value is True, we check that the gradients are correct + # If the value is False, we test that forward grad is not implemented + supports_forward_ad: bool = False + + # Whether the operation has a varargs variant + # (e.g. functions like ones, zeros, methods like view, permute) + supports_varargs: bool = False + + # Whether the forward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_forward: bool = True + + # Whether the backward operation avoids materializing COW tensor inputs + supports_cow_input_no_materialize_backward: bool = True + + # Whether to skip the backward part of the COW tensor input test + skip_cow_input_backward: bool = False + + # If `supports_cow_input_no_materialize_forward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_forward: list[Union[int, str]] = None + + # If `supports_cow_input_no_materialize_backward == True`, this list contains + # the arg indices or kwarg names of inputs that are expected to materialize + allow_cow_input_materialize_backward: list[Union[int, str]] = None + + # wrapper function for gradcheck + gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs) + + # whether to check batched grad when doing gradcheck + # defaults to support_autograd's value + check_batched_grad: bool = None + + # whether to check batched grad grad when doing gradgradcheck + # default's to support_gradgrad's value + check_batched_gradgrad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `supports_forward_ad` + check_batched_forward_grad: bool = None + + # whether to check batched forward grad when doing gradcheck + # defaults to the value of `check_batched_forward_grad` + check_inplace_batched_forward_grad: bool = None + + # tolerance for nondeterminism while performing gradcheck + gradcheck_nondet_tol: float = 0.0 + + # Whether to use the fast implementation for gradcheck/gradgradcheck. + # When set to None, defers to the default value provided by the wrapper + # function around gradcheck (testing._internal.common_utils.gradcheck) + gradcheck_fast_mode: bool = None + + # the following metadata relates to JIT support and is tested for correctness in test_ops.py + + # name of the corresponding aten:: operator + aten_name: str = None + + # if this is a composite implicit autograd op, the decomposed op + decomp_aten_name: Optional[str] = None + + # name of the corresponding aten:: operator for backwards + aten_backward_name: Optional[str] = None + + # if a op's aten::node is expected to be symbolically autodiffed + assert_autodiffed: bool = False + + # a list of strings with node names that are expected to be in a + # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'], + # default is populated to be ['aten::(name of Python operator)'] + autodiff_nonfusible_nodes: list[str] = None + + # a list of strings with node names that are expected to be in FusionGroups + # inside of DifferentiableGraphs when this operation is autodiffed. + # Ex: ['aten::add', 'aten::mm'], defaults to an empty list + # Note: currently no ops use fusible nodes + autodiff_fusible_nodes: list[str] = None + + # the following metadata relates to sparse support and is used in test_sparse.py + + # whether the op supports sparse coo inputs, defaults to False + # TODO: rename supports_sparse to supports_sparse_coo + supports_sparse: bool = None + + # only run tracing tests + supports_scripting: bool = True + + # if the operator can be traced + supports_tracing: bool = True + + # the following metadata relates to sparse compressed support and + # is used in test_sparse_csr.py and test_sparse.py + + # whether the op supports sparse csr inputs, defaults to False + supports_sparse_csr: bool = None + # whether the op supports sparse csc inputs, defaults to False + supports_sparse_csc: bool = None + # whether the op supports sparse bsr inputs, defaults to False + supports_sparse_bsr: bool = None + # whether the op supports sparse bsc inputs, defaults to False + supports_sparse_bsc: bool = None + # whether the op supports nested jagged inputs, defaults to False + supports_njt: bool = None + + # whether the op promotes integer inputs to float + promotes_int_to_float: bool = False + + # the following metadata relates to complex support and is checked in test_ops.py + + test_conjugated_samples: bool = True + + test_neg_view: bool = True + + # assert that jit shape analysis fully propagates shape + assert_jit_shape_analysis: bool = False + + # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py + + supports_expanded_weight: bool = False + + is_factory_function: bool = False + + skip_correctness_check_compile_vs_eager: bool = False + + def __post_init__(self): + self._original_opinfo_args = asdict(self).copy() + + assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!" + + # Validates the dtypes are generated from the dispatch-related functions + for name, val in self.dtypesIf.items(): + if val is not None: + assert isinstance(val, _dispatch_dtypes) + self.dtypesIf[name] = set(val) + + if self.aten_name is None: + self.aten_name = self.name + + # Attribute to verify dynamic_dtypes are used. + self.dynamic_dtypes = any( + isinstance(dtypes, utils._dynamic_dispatch_dtypes) + for dtypes in self.dtypesIf.values() + ) + + if self.dynamic_dtypes: + # Make sure `dtyesIfCUDA` is dynamic, if dynamic dispatch is used for CPU + # This is because, below we set dtypesIfCUDA to dtypes if they are None. + assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), ( + f"To use dynamic dtypes for operator {self.name}, " + "acquire the dtypes dynamically for argument `dtypesIfCUDA`." + "This is to ensure that CUDA dtypes are acquired correctly as they" + "differ from CPU dtypes occasionally" + ) + + self.dtypes = set(self.dtypes) + + # NOTE: backward dtypes must be acquired before forward dtypes + # since they fallback to explicit (not implicit!) specifications of + # forward dtypes + self.backward_dtypesIfROCM = ( + set(self.backward_dtypesIfROCM) + if self.backward_dtypesIfROCM is not None + else ( + self.backward_dtypesIfCUDA + if self.backward_dtypesIfCUDA is not None + else self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfROCM + if self.dtypesIfROCM is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfCUDA = ( + set(self.backward_dtypesIfCUDA) + if self.backward_dtypesIfCUDA is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypesIfCUDA + if self.dtypesIfCUDA is not None + else self.dtypes + ) + ) + self.backward_dtypesIfHpu = ( + set(self.backward_dtypesIfHpu) + if self.backward_dtypesIfHpu is not None + else ( + self.backward_dtypes + if self.backward_dtypes is not None + else self.dtypes + ) + ) + + self.backward_dtypes = ( + set(self.backward_dtypes) + if self.backward_dtypes is not None + else self.dtypes + ) + + # Inherit from cpu + for dev_type in ["cuda", "hpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypes + + # Inherit from CUDA + for dev_type in ["rocm", "xpu"]: + if self.dtypesIf.get(dev_type) is None: + self.dtypesIf[dev_type] = self.dtypesIf["cuda"] + + # NOTE: if the op is unspecified it is assumed to be under the torch namespace + if not self.op: + self.op = _getattr_qual(torch, self.name) + + if self.method_variant is _NOTHING: + self.method_variant = getattr(torch.Tensor, self.name, None) + + # attributes like real, imag are not callable + if not callable(self.method_variant): + self.method_variant = None + + if self.inplace_variant is _NOTHING: + inplace_name = self.name + "_" + self.inplace_variant = getattr(torch.Tensor, inplace_name, None) + + if self.operator_variant is _NOTHING: + self.operator_variant = getattr(operator, self.name, None) + + if self.inplace_operator_variant is _NOTHING: + # Note: operator.i will use operator. and assign the result to the lhs when no + # __i__ method is found. This results in the appearance of an inplace operator variant which + # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace + # operator with a check that an inplace variant exists. + if self.inplace_variant is not None: + inplace_operator_name = "i" + self.name + self.inplace_operator_variant = getattr( + operator, inplace_operator_name, None + ) + else: + self.inplace_operator_variant = None + + self.decorators = (*self.decorators, *self.skips) + + # Specifying sample inputs function without specifying the + # corresponding layout support implies the layout support: + if self.supports_sparse is None: + self.supports_sparse = self.sample_inputs_sparse_coo_func is not None + if self.sample_inputs_sparse_coo_func is None: + self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified + + if self.supports_sparse_csr is None: + self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None + if self.sample_inputs_sparse_csr_func is None: + self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified + + if self.supports_sparse_csc is None: + self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None + if self.sample_inputs_sparse_csc_func is None: + self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsr is None: + self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None + if self.sample_inputs_sparse_bsr_func is None: + self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified + + if self.supports_sparse_bsc is None: + self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None + if self.sample_inputs_sparse_bsc_func is None: + self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified + + if self.supports_njt is None: + self.supports_njt = False + + # We run the sampling functions without tracking the gradiends of the creation of inputs + self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func) + self.sample_inputs_sparse_coo_func = torch.no_grad()( + self.sample_inputs_sparse_coo_func + ) + self.sample_inputs_sparse_csr_func = torch.no_grad()( + self.sample_inputs_sparse_csr_func + ) + self.sample_inputs_sparse_csc_func = torch.no_grad()( + self.sample_inputs_sparse_csc_func + ) + self.sample_inputs_sparse_bsr_func = torch.no_grad()( + self.sample_inputs_sparse_bsr_func + ) + self.sample_inputs_sparse_bsc_func = torch.no_grad()( + self.sample_inputs_sparse_bsc_func + ) + if self.reference_inputs_func is not None: + self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func) + + if not self.autodiff_fusible_nodes: + self.autodiff_fusible_nodes = [] + + if self.autodiff_nonfusible_nodes is None: + self.autodiff_nonfusible_nodes = ["aten::" + self.name] + + # Autograd support + + # Autograd flags that depend on backward AD only + # - If setting has been explicitly set, raise error if inconsistent + if self.supports_gradgrad is None: + self.supports_gradgrad = self.supports_autograd + else: + assert not (self.supports_gradgrad and not self.supports_autograd), ( + "supports_gradgrad refines the part of autograd is supported, so it should " + "not be set if supports_autograd is False" + ) + if self.check_batched_grad is None: + self.check_batched_grad = self.supports_autograd or self.supports_forward_ad + else: + assert not ( + self.check_batched_grad + and not (self.supports_autograd or self.supports_forward_ad) + ), ( + "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so " + "it should not be set if supports_autograd is False" + ) + if self.check_batched_gradgrad is None: + self.check_batched_gradgrad = self.supports_gradgrad + else: + assert not (self.check_batched_gradgrad and not self.supports_gradgrad), ( + "check_batched_gradgrad refines the part of autograd that will be checked (by " + "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd " + "is False." + ) + if self.check_batched_forward_grad is None: + self.check_batched_forward_grad = self.supports_forward_ad + else: + assert not ( + self.check_batched_forward_grad and not self.supports_forward_ad + ), ( + "check_batched_forward_grad should only be used when supports_forward_ad " + "is True. It is used to disable the test in the specific cases " + "where the op supports forward ad but fails to compute " + "batched forward grad." + ) + + if self.check_inplace_batched_forward_grad is None: + self.check_inplace_batched_forward_grad = self.check_batched_forward_grad + else: + assert not ( + self.check_inplace_batched_forward_grad + and not self.check_batched_forward_grad + ), ( + "check_batched_forward_grad should only be used when check_batched_forward_grad " + "is True. It is used to disable the test in the specific cases " + "where the op supports batched forward grad but fails to compute batched forward " + "grad for the inplace variant of the op." + ) + + assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), ( + "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be " + "True if backward ad is also checked, i.e., supports_forward_ad should be True.", + self.name, + ) + + # Autograd flags that depend on both forward AD and backward AD + if self.supports_inplace_autograd is None: + self.supports_inplace_autograd = ( + self.supports_autograd or self.supports_forward_ad + ) + else: + assert not ( + self.supports_inplace_autograd + and not self.supports_autograd + and not self.supports_forward_ad + ), ( + "supports_inplace_autograd refines the part of autograd that is supported, so " + "it should not be set if both supports_autograd and supports_forward_ad are False" + ) + + if self.aliases is not None: + self.aliases = tuple(AliasInfo(a) for a in self.aliases) # type: ignore[assignment] + else: + self.aliases = () + + def __call__(self, *args, **kwargs): + """Calls the function variant of the operator.""" + return self.op(*args, **kwargs) + + def __str__(self): + return dataclass_repr(self) + + def get_op(self): + """Returns the function variant of the operator, torch..""" + return self.op + + def get_method(self): + """Returns the method variant of the operator, torch.Tensor.. + Returns None if the operator has no method variant. + """ + return self.method_variant + + def get_inplace(self): + """Returns the inplace variant of the operator, torch.Tensor._. + Returns None if the operator has no inplace variant. + """ + return self.inplace_variant + + def get_operator(self): + """Returns operator variant of the operator, e.g. operator.neg + Returns None if the operator has no operator variant. + """ + return self.operator_variant + + def get_inplace_operator(self): + """Returns the inplace operator variant of the operator, e.g operator.iadd + Returns None if the operator has no inplace operator variant""" + return self.inplace_operator_variant + + # Returns a tuple of callables: + # (TestCase -> subtest context, TestCase -> skip / xfail context) + # I'd love to combine these into one but I haven't figured out how to do it + # in a way that works like it should, and I tried a LOT of things. + def _maybe_skip_or_xfail(self, rules, device, sample, idx): + def _subtest_fn(test_case, sample=sample.name, idx=idx): + return test_case.subTest(sample=sample, idx=idx) + + if rules is None or len(rules) == 0: + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + # NB: match first rule only (order matters!) + for rule in rules: + if rule.sample_match_fn(device, sample): + log.debug( + "matched %s rule '%s': %s %s %s", + rule.type, + rule.name, + self.full_name, + device, + sample, + ) + + # Provide a context for the test case to run the sample input + # through as a subtest AND handle skip / xfail for it as needed. + return ( + _subtest_fn, + lambda test_case, rule=rule: rule.get_context(test_case), + ) + + log.debug("matched no rules: %s %s %s", self.full_name, device, sample) + return (_subtest_fn, lambda _: contextlib.nullcontext()) + + def _sample_callback_fn(self, use_subtests, device): + # Get sample-specific skips / xfails. + sample_skips_and_xfails = getattr( + extract_test_fn(), "sample_skips_and_xfails", None + ) + + if sample_skips_and_xfails is not None and not use_subtests: + raise RuntimeError( + """Sample-specific skips / xfails require use_subtests=True. +Please pass this to the sample generation function and run the test logic within the +returned contexts (NB: order matters!). For example: + +def test_foo(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(..., use_subtests=True): + # these contexts handle running within subtests and skips / xfails + with subtest_ctx(self), skip_xfail_ctx(self): + # test logic here + ...""" + ) + + if not use_subtests: + # use the default callback that returns the sample without a subtest context + return None + + if USE_PYTEST: + try: + import pytest_subtests # noqa: F401 + except ModuleNotFoundError: + raise RuntimeError( + "Encountered an OpInfo test with use_subtests=True and pytest-subtests is " + "not installed. The feature will not work correctly within pytest without " + "this package; please install it." + ) from None + + def _f( + sample, + idx, + self=self, + device=device, + sample_skips_and_xfails=sample_skips_and_xfails, + use_subtests=use_subtests, + ): + # When subtests are enabled, also return a subtest context. This is required + # for xfails / skips to work properly. + return ( + sample, + *self._maybe_skip_or_xfail( + sample_skips_and_xfails, device, sample, idx + ), + ) + + return _f + + def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs but with the tensor input or first + tensor in a sequence input conjugated. + """ + + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + conj_samples = list(samples) + + def conjugate(tensor): + _requires_grad = tensor.requires_grad + tensor = tensor.conj() + return tensor.requires_grad_(_requires_grad) + + for i, sample in enumerate(samples): + sample = conj_samples[i] + # Note: it is assumed that the input here is either a tensor or tensorlist + if isinstance(sample.input, torch.Tensor): + sample.input = conjugate(sample.input) + else: + sample.input[0] = conjugate(sample.input[0]) + + return TrackedInputIter( + iter(conj_samples), + "conjugate sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + These samples should be sufficient to test the function works correctly + with autograd, TorchScript, etc. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) + + if kwargs.get("include_conjugated_inputs", False): + conj_samples = self.conjugate_sample_inputs( + device, dtype, requires_grad, **kwargs + ) + samples_list = list(samples) + samples_list.extend(conj_samples) + samples = tuple(samples_list) + + return TrackedInputIter( + iter(samples), + "sample input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): + """ + Returns an iterable of SampleInputs. + + Distinct from sample_inputs() above because this returns an expanded set + of inputs when reference_inputs_func is defined. If undefined this returns + the sample inputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + if self.reference_inputs_func is None: + samples = self.sample_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(samples), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + if kwargs.get("include_conjugated_inputs", False): + raise NotImplementedError + + references = self.reference_inputs_func( + self, device, dtype, requires_grad, **kwargs + ) + return TrackedInputIter( + iter(references), + "reference input", + item_callback=self._sample_callback_fn(use_subtests, device), + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs(self, device, **kwargs): + """ + Returns an iterable of ErrorInputs. + """ + set_seed = kwargs.pop("set_seed", True) + use_subtests = kwargs.pop("use_subtests", False) + errs = self.error_inputs_func(self, device, **kwargs) + + def _error_item_callback(e, i, use_subtests=use_subtests, device=device): + cb = self._sample_callback_fn(use_subtests, device) + # no rules to apply; just return the sample + if cb is None: + return e + + # adapt the callback call since ErrorInputs contain SampleInputs + _, subtest_ctx = cb(e.sample_input, i) + return (e, subtest_ctx) + + return TrackedInputIter( + iter(errs), + "error input", + track_callback=lambda e: e.sample_input, + item_callback=_error_item_callback, + set_seed=set_seed, + restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, + ) + + def error_inputs_sparse(self, device, layout, **kwargs): + """ + Returns an iterable of ErrorInputs that contain sparse sample + inputs with a specified layout. + """ + if not self.supports_sparse_layout(layout): + raise unittest.SkipTest("unsupported sparse layout") + return self.error_inputs_sparse_func(self, device, layout, **kwargs) + + def supports_sparse_layout(self, layout): + """Return True if OpInfo supports the specified sparse layout.""" + layout_name = str(layout).split(".")[-1] + # map torch.sparse_coo to OpInfo.supports_sparse: + layout_name = layout_name.replace("_coo", "") + return getattr(self, f"supports_{layout_name}") + + def sample_inputs_sparse( + self, layout, device, dtype, requires_grad=False, **kwargs + ): + """Returns an iterable of SampleInputs that contain inputs with a + specified sparse layout. + """ + layout_name = str(layout).split(".")[-1] + sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name) + + def non_empty_sampler(op, generator): + found_sample = False + for sample in generator: + found_sample = True + yield sample + if not found_sample: + raise unittest.SkipTest("NO SAMPLES!") + + return non_empty_sampler( + self, + sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs), + ) + + def _sample_inputs_unspecified(self, *args, **kwargs): + """Raises an NotImplemented exception in a OpInfo instance creation + that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True + without specifying the corresponding sample function as + sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func. + + To avoid this, either define the corresponding sample function, + or re-map unsupported samples to error inputs in an appropriate + + opinfo/definitions/sparse.py:_validate_sample_input_sparse_ + + function. + """ + raise NotImplementedError("no sample function specified") + + def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + coo layout. + """ + return self.sample_inputs_sparse_coo_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csr layout. + """ + return self.sample_inputs_sparse_csr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + csc layout. + """ + return self.sample_inputs_sparse_csc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsr layout. + """ + return self.sample_inputs_sparse_bsr_func( + self, device, dtype, requires_grad, **kwargs + ) + + def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs): + """Returns an iterable of SampleInputs that contain inputs with sparse + bsc layout. + """ + return self.sample_inputs_sparse_bsc_func( + self, device, dtype, requires_grad, **kwargs + ) + + def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): + """Returns the decorators targeting the given test.""" + result = [] + for decorator in self.decorators: + if isinstance(decorator, DecorateInfo): + if decorator.is_active( + test_class, test_name, device, dtype, param_kwargs + ): + result.extend(decorator.decorators) + else: + result.append(decorator) + return result + + def supported_dtypes(self, device_type): + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + if device_type == "cuda" and TEST_WITH_ROCM: + device_type = "rocm" + result = self.dtypesIf.get(device_type, self.dtypes) + if device_type == "mps": + return result - {torch.float64, torch.cdouble} + return result + + def supported_backward_dtypes(self, device_type): + if not self.supports_autograd: + return set() + + if device_type == "privateuse1": + device_type = torch._C._get_privateuse1_backend_name() + device_type = torch.device(device_type).type + backward_dtypes = None + if device_type == "cuda": + backward_dtypes = ( + self.backward_dtypesIfROCM + if TEST_WITH_ROCM + else self.backward_dtypesIfCUDA + ) + elif device_type == "hpu": + backward_dtypes = self.backward_dtypesIfHpu + elif device_type == "mps": + backward_dtypes = self.backward_dtypes - {torch.double, torch.cdouble} + else: + backward_dtypes = self.backward_dtypes + + allowed_backward_dtypes = floating_and_complex_types_and( + torch.bfloat16, torch.float16, torch.complex32 + ) + return set(allowed_backward_dtypes).intersection(backward_dtypes) + + def supports_dtype(self, dtype, device_type) -> bool: + return dtype in self.supported_dtypes(device_type) + + @property + def full_name(self): + """Returns a full name that helps to uniquely identify this OpInfo.""" + variant = "." + self.variant_test_name if self.variant_test_name else "" + # example: "normal.in_place" where "normal" is the name and "in_place" is the variant + return f"{self.name}{variant}" + + @property + def formatted_name(self): + """Returns a formatted full name for this OpInfo that can be used in test names.""" + return self.full_name.replace(".", "_") + + +# Represents a skip / xfail rule matching a particular set of tests. It allows granularity +# at the device, dtype, op, and individual sample levels. This flexibility allows entire +# bugs to be represented by a single rule, even if this corresponds with multiple conceptual +# test cases across multiple ops. +@dataclass +class SampleRule(ABC): + # function to indicate whether the rule applies to this op; return True if so + # NB: str arg of callable is device_type + op_match_fn: Callable[[str, OpInfo], bool] = None + # function to indicate whether the rule applies to this sample; return True if so + sample_match_fn: Callable[[torch.device, SampleInput], bool] = None + # optional name for identifying the rule + name: str = "" + + def __post_init__(self): + if self.op_match_fn is None: + raise ValueError("must have op_match_fn set to be useful") + if self.sample_match_fn is None: + # by default, match for all samples + self.sample_match_fn = lambda device, sample: True + + # returns a string identifier of the rule type + @abstractmethod + def type(self) -> str: ... + + # returns an appropriate context that handles the xfail, skips, etc. + @abstractmethod + def get_context(self, test_case): ... + + +# useful for specifying xfails +@dataclass +class XFailRule(SampleRule): + # expected error type + error_type: TypeVar = Exception + # expected error message + error_msg: str = ".*" + + @property + def type(self) -> str: + return "xfail" + + def get_context(self, test_case): + return test_case.assertRaisesRegex( + # failing within torch.compile wraps within a BackendCompilerFailed + (self.error_type, torch._dynamo.exc.BackendCompilerFailed), + self.error_msg, + ) + + +# useful for specifying skips +@dataclass +class SkipRule(SampleRule): + @property + def type(self): + return "skip" + + def get_context(self, test_case): + @contextlib.contextmanager + def skipcontext(test_case=test_case): + test_case.skipTest("Skipped!") + yield + + return skipcontext() + + +# Decorator that defines skip / xfail rules for a given test function. If these are +# present, the @ops decorator will apply these for each op and place them onto the +# parametrized test functions for use by e.g. OpInfo.sample_inputs(). +class sample_skips_and_xfails: + def __init__(self, rules): + self.rules = rules + + def __call__(self, fn): + rules = getattr(fn, "sample_skips_and_xfails", None) + if rules is not None: + raise RuntimeError("Multiple sets of sample_skips_and_xfails defined") + + fn.sample_skips_and_xfails = self.rules + return fn + + +def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs): + """Generates input tensors for testing reduction operators""" + yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad) + yield make_tensor( + [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad + ) + + +def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): + """Generates a subset of all valid dim and keepdim kwargs given ndim that + is appropriate for testing reduction operators. + """ + + # Test default dim and keepdim + yield {} + + # Test reducing inner and outer most dimensions + yield {"dim": 0, "keepdim": True} + yield {"dim": -1, "keepdim": False} + + # Test reducing middle dimension + if ndim > 2: + yield {"dim": ndim // 2, "keepdim": True} + + if supports_multiple_dims: + # Test reducing all dimensions + yield {"dim": tuple(range(ndim)), "keepdim": False} + + # Test reducing both first and last dimensions + if ndim > 1: + yield {"dim": (0, -1), "keepdim": True} + + # Test reducing every other dimension starting with the second + if ndim > 3: + yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False} + + +def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for reduction operators.""" + + # TODO(@heitorschueroff) Once all reduction operators are using + # ReductionOpInfo use op_info.supports_multiple_dims directly. + supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True) + + # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo + # use op_info.generate_args_kwargs directly. + generate_args_kwargs = kwargs.get( + "generate_args_kwargs", lambda *args, **kwargs: (yield (), {}) + ) + + for t in _generate_reduction_inputs(device, dtype, requires_grad): + for reduction_kwargs in _generate_reduction_kwargs( + t.ndim, supports_multiple_dims + ): + for args, kwargs in generate_args_kwargs(t, **reduction_kwargs): + kwargs.update(reduction_kwargs) + yield SampleInput( + t.detach().requires_grad_(requires_grad), args=args, kwargs=kwargs + ) + + +# NOTE [Reductions]: +# +# For testing purposes, we relax the definition of a reduction operator +# as defined in the docstring below. We do this to capture operators with +# a similar API so they can be tested automatically. However... +# +# Strictly speaking a reduction operator is an operator that can reduce an +# array to a single scalar value and that can be computed from the partial +# result of reducing subarrays. This usually means that the reduction operation +# should be commutative and associative. This definition is important when it +# comes to implementation as it determines how a reduction can be parallelized. +# +# For example, many summary statistics such as median, mode and quantile cannot +# be computed from partial results because these are sorting and counting based +# algorithms that need information that would be lost in the reduced value. +class ReductionOpInfo(OpInfo): + """Reduction operator information. + + An operator is a reduction operator if it reduces one or more dimensions of + the input tensor to a single value. Reduction operators must implement the + following signature: + + - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor` + + ReductionOpInfo tests that reduction operators implement a consistent API. + Optional features such as reducing over multiple dimensions are captured in + the optional keyword parameters of the ReductionOpInfo constructor. + + If a reduction operator does not yet implement the full required API of + reduction operators, this should be documented by xfailing the failing + tests rather than adding optional parameters to ReductionOpInfo. + + NOTE + The API for reduction operators has not yet been finalized and some + requirements may change. + + See tests in test/test_reductions.py + """ + + def __init__( + self, + name, + *, + # The identity value for the operator if it has one. + identity: Optional[Any] = None, + # The nan policy for the operator if it implements one. + # - propagate: NaN values are propagated to the output + # - omit: NaN values are discarded during the reduction + nan_policy: Optional[str] = None, + # Whether the operator supports reducing multiple dimensions. + supports_multiple_dims: bool = True, + # Whether the operator promotes integral to floating point dtypes. + promotes_int_to_float: bool = False, + # Whether the operator promotes all integral dtypes to int64. + promotes_int_to_int64: bool = False, + # If a specific dtype is given, then the operator always returns that + # dtype irrespective of the input dtype. If None, the operator returns + # the dtype according to the type promotion rules above. + result_dtype: Optional[torch.dtype] = None, + # Casts complex results to real (e.g. linalg.norm or torch.var) + complex_to_real: bool = False, + # ReductionOpInfo tests generate their own input, dim and keepdim + # arguments and call this function to generate tuples of extra args and + # kwargs to use when calling the op. This is required for operators that + # have other required parameters besides the input tensor. + generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( + yield ( + (), + {}, + ) + ), + # Options from the OpInfo base class + **kwargs, + ): + self._original_reduction_args = locals().copy() + assert nan_policy in (None, "propagate", "omit") + + # These are mutually exclusive options + assert not (result_dtype and promotes_int_to_float) + assert not (result_dtype and promotes_int_to_int64) + assert not (result_dtype and complex_to_real) + assert not (promotes_int_to_float and promotes_int_to_int64) + + # Default sample_inputs_func for ReductionOpInfo which augments sample + # inputs from sample_inputs_reduction with the args and kwargs from + # generate_args_kwargs. This is only used if sample_inputs_func is None. + def sample_inputs_func(*args, **kwargs): + kwargs["supports_multiple_dims"] = supports_multiple_dims + kwargs["generate_args_kwargs"] = generate_args_kwargs + yield from sample_inputs_reduction(*args, **kwargs) + + # Override OpInfo defaults and call base class __init__ + kwargs.setdefault("inplace_variant", None) + kwargs.setdefault("sample_inputs_func", sample_inputs_func) + super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs) + + self.identity = identity + self.nan_policy = nan_policy + self.supports_multiple_dims = supports_multiple_dims + self.promotes_int_to_int64 = promotes_int_to_int64 + self.complex_to_real = complex_to_real + self.result_dtype = result_dtype + self.generate_args_kwargs = generate_args_kwargs + + +# The base reference input generation for elementwise binary operations +def _reference_inputs_elementwise_binary( + op, device, dtype, requires_grad, exclude_zero, **kwargs +): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + yield from generate_elementwise_binary_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + if dtype is not torch.bool: + yield from generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + if dtype not in (torch.bool, torch.uint8, torch.int8): + yield from generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield from generate_elementwise_binary_broadcasting_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + yield from generate_elementwise_binary_with_scalar_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + if dtype.is_floating_point or dtype.is_complex: + yield from generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + +# Note that these references inputs use scalars for the SampleInput.input value, +# and many tests require SampleInput.input be a tensor or a list of tensors +def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + gen = partial( + _reference_inputs_elementwise_binary, + op, + device, + dtype, + requires_grad, + exclude_zero, + **kwargs, + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_binary_noncontiguous_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + yield from generate_elementwise_binary_arbitrarily_strided_tensors( + op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + +# A functional that extends an elementwise binary operator's bespoke error inputs +# with generic error inputs for the class of elementwise binary operations +def make_error_inputs_elementwise_binary(error_inputs_func): + def error_inputs_func_wrapper(op, device, **kwargs): + if error_inputs_func is not None: + yield from error_inputs_func(op, device, **kwargs) + + if not op.supports_rhs_python_scalar: + si = SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if not op.supports_one_python_scalar: + si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + if ( + not kwargs.get("skip_two_python_scalars", False) + and not op.supports_two_python_scalars + ): + si = SampleInput(2, args=(3,)) + yield ErrorInput(si, error_type=Exception, error_regex="") + + return error_inputs_func_wrapper + + +# The following functions and classes are for testing elementwise binary operators. + + +# Returns a generator of pairs of contiguous tensors on the requested device +# and with the requested dtype. +# +# This function is intended to test the non-vectorized and vectorized code +# paths of elementwise binary functions, as well as their handling of odd tensor +# sizes (like zero-dim tensors and tensors with zero elements). +# +# Each iterable will include an a tensor with no elements, +# zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and +# a large 2D tensor. +def generate_elementwise_binary_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + # tensors with no elements + (0,), + (1, 0, 3), + # zero dim (scalar) tensor + (), + # small 1D tensor + (20,), + # medium 1D tensor + (812,), + # large 2D tensor + (1029, 917), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + yield SampleInput( + lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + +def generate_elementwise_binary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + b = make_arg(shape) + yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Returns a generator of pairs of contiguous tensors on the requested device and with +# the requested dtype. +# +# Unlike the previous function, the values in these tensors are specified manually. +def generate_elementwise_binary_small_value_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=None +): + if exclude_zero is None: + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + # defines interesting values + _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254) + _int_vals = (0, -1, 1, -55, 55, -127, 127, -128) + _float_vals = ( + 0.0, + -0.0, + -0.001, + 0.001, + -0.25, + 0.25, + -1.0, + 1.0, + -math.pi / 2, + math.pi / 2, + -math.pi + 0.00001, + math.pi - 0.00001, + -math.pi, + math.pi, + -math.pi - 0.00001, + math.pi + 0.00001, + ) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_vals, _float_vals) + elif dtype.is_complex: + complex_vals = product(_float_vals, _float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64): + prod = product(_int_vals, _int_vals) + elif dtype is torch.uint8: + prod = product(_unsigned_int_vals, _unsigned_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + if r == 0 and exclude_zero: + r_vals.append(1) + else: + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + _large_int_vals = (-1113, 1113, -10701, 10701) + _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7) + _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) + + l_vals = [] + r_vals = [] + + if dtype == torch.float16: + prod = product(_large_float16_vals, _large_float16_vals) + elif dtype.is_floating_point: + prod = product(_large_float_vals, _large_float_vals) + elif dtype.is_complex: + complex_vals = product(_large_float_vals, _large_float_vals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + elif dtype in (torch.int16, torch.int32, torch.int64): + prod = product(_large_int_vals, _large_int_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +def generate_elementwise_binary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + _float_extremals = (float("inf"), float("-inf"), float("nan")) + + l_vals = [] + r_vals = [] + + if dtype.is_floating_point: + prod = product(_float_extremals, _float_extremals) + elif dtype.is_complex: + complex_vals = product(_float_extremals, _float_extremals) + # Note the use of list is required here or the map generator will be + # emptied by the following product and it won't produce the desired cross-product + complex_vals = [complex(*x) for x in complex_vals] + prod = product(complex_vals, complex_vals) + else: + raise ValueError("Unsupported dtype!") + + for l, r in prod: + l_vals.append(l) + r_vals.append(r) + + lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) + rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + # Test case for NaN propagation + nan = ( + float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan")) + ) + lhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + lhs.view(-1)[::3] = nan + rhs = make_tensor( + (128, 128), device=device, dtype=dtype, requires_grad=requires_grad + ) + rhs.view(-1)[::3] = nan + + yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]) + + +# Returns a generator of pairs of contiguous and noncontiguous tensors that +# require broadcasting +def generate_elementwise_binary_broadcasting_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + shapes = ( + ((1,), ()), + ((2,), ()), + ((1,), (2,)), + ((2, 1), (2,)), + ((1, 2), (2,)), + ((3, 2), (2,)), + ((1, 3, 2), (2,)), + ((1, 3, 2), (3, 2)), + ((3, 1, 2), (3, 2)), + ((2, 3, 2), ()), + ((3, 1, 2), (1, 3, 2)), + ) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + for shape, noncontiguous in product(shapes, [True, False]): + shape_lhs, shape_rhs = shape + lhs = make_arg( + shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs + ) + rhs = make_arg( + shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs + ) + + yield SampleInput( + lhs, + args=(rhs,), + broadcasts_input=True, + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and scalars +def generate_elementwise_binary_with_scalar_samples( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5)) + if op.supports_rhs_python_scalar: + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + lhs_scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + if op.supports_two_python_scalars: + lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() + rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() + + yield SampleInput( + lhs_scalar, + args=(rhs_scalar,), + kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0], + ) + + +# Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion +def generate_elementwise_binary_with_scalar_and_type_promotion_samples( + op, *, device, dtype, requires_grad=False +): + # add these samples only for logical and comparison ops, arithmetic ops are not happy about extremal scalars + if op.name in ( + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "logical_and", + "logical_or", + "logical_xor", + ): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + shape = ( + 23, + ) # this shape is big enough to trigger vectorization, and has non-vectorized tail + values = (float("nan"), float("inf"), -float("inf")) + scalar_tensors = tuple(torch.tensor(val) for val in values) + if op.supports_rhs_python_scalar: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + for scalar in values + scalar_tensors: + yield SampleInput( + lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + # Extends with scalar lhs + if op.supports_one_python_scalar: + yield SampleInput( + scalar, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, scalar)[0], + ) + + +# Returns a generator of pairs of noncontiguous tensors +def generate_elementwise_binary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=False +): + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + # Generic noncontiguity + lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs) + rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + yield SampleInput( + lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # Transposed + lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs) + rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs) + + yield SampleInput( + lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0] + ) + + # More noncontiguity + shapes = ((5, 7), (1024,)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + lhs_non_contig.copy_(lhs) + + rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + rhs_non_contig.copy_(rhs) + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Noncontiguous indices + shape = (2, 2, 1, 2) + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs[:, 1, ...] + rhs_non_contig = rhs[:, 1, ...] + + yield SampleInput( + lhs_non_contig.clone(), + args=(rhs_non_contig.clone(),), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + yield SampleInput( + lhs_non_contig.contiguous(), + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) + + lhs_non_contig = lhs.expand(3, -1, -1) + rhs_non_contig = rhs.expand(3, -1, -1) + + yield SampleInput( + lhs_non_contig, + args=(rhs_non_contig,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + ) + + +# Sample inputs for elementwise binary operators, like add +def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): + _M = S if kwargs.get("small_inputs_only", False) else M + _S = XS if kwargs.get("small_inputs_only", False) else S + + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + + make_arg = partial( + make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + exclude_zero=exclude_zero, + ) + + shapes = ( + ((), ()), + ((_S,), ()), + ((_S, 1), (_S,)), + ((_M, _S), ()), + ((_S, _M, _S), (_M, _S)), + ((_S, _M, _S), (_S, _M, _S)), + ((_M, 1, _S), (_M, _S)), + ((_M, 1, _S), (1, _M, _S)), + ((0, 1, XS), (0, _M, XS)), + ) + + for shape_lhs, shape_rhs in shapes: + lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs) + rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput( + lhs, + args=(rhs,), + kwargs=op.sample_kwargs(device, dtype, lhs)[0], + broadcasts_input=broadcasts_input, + ) + + +# Metadata class for binary "universal functions (ufuncs)" that accept two +# tensor and have common properties +class BinaryUfuncInfo(OpInfo): + """Operator information for 'universal binary functions (binary ufuncs).' + These are functions of two tensors with common properties like: + - they are elementwise functions + - the output shape is determined by the input shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/stable/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, + *, + sample_inputs_func=sample_inputs_elementwise_binary, + reference_inputs_func=reference_inputs_elementwise_binary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + error_inputs_func=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + always_returns_bool=False, # Set to true if the op always returns bool tensors + supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs + supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs + supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs + **kwargs, + ): + self._original_binary_ufunc_args = locals().copy() + + # Elementwise binary operations perform the equivalent of test_numpy_refs + # in test_binary_ufuncs, but with additional test granularity. So the + # generic test_ops.py test is skipped because it's redundant. + common_skips = ( + DecorateInfo( + unittest.skip("Skipping redundant test."), + "TestCommon", + "test_numpy_refs", + ), + ) + kwargs["skips"] = kwargs.get("skips", ()) + common_skips + super().__init__( + name, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func), + **kwargs, + ) + + self.sample_kwargs = sample_kwargs + + # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. + if lhs_make_tensor_kwargs is None: + lhs_make_tensor_kwargs = {} + self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs + + if rhs_make_tensor_kwargs is None: + rhs_make_tensor_kwargs = {} + self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs + + self.always_returns_bool = always_returns_bool + self.supports_rhs_python_scalar = supports_rhs_python_scalar + self.supports_one_python_scalar = supports_one_python_scalar + self.supports_two_python_scalars = supports_two_python_scalars + + if self.supports_two_python_scalars: + self.supports_one_python_scalar = True + + if self.supports_one_python_scalar: + assert supports_rhs_python_scalar, ( + "Can't support lhs and rhs Python scalars but not rhs scalars!" + ) + + +# The following functions and classes are for testing elementwise unary operators. +def sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + _L = S if kwargs.get("small_inputs_only", False) else L + + low, high = op_info.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op_info._domain_eps + high = high if high is None or not is_floating else high - op_info._domain_eps + if ( + op_info.supports_sparse_csr + or op_info.supports_sparse_csc + or op_info.supports_sparse_bsr + or op_info.supports_sparse_bsc + ): + # Tensors with dim=2 for sparse compressed testing + yield SampleInput( + make_tensor( + (_L, _L), + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + else: + # Creates a 1D, empty, and scalar tensor + for shape in ((_L,), (1, 0, 3), ()): + yield SampleInput( + make_tensor( + shape, + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + + +# Replace values satisfying condition with a safe value. This is used to block +# out values the could cause singularity like tan(pi/2) +def _replace_values_in_tensor(tensor, condition, safe_value): + mask = condition(tensor) + tensor.masked_fill_(mask, safe_value) + + +# Helper to create a unary elementwise tensor with valid inputs +def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs): + low, high = op.domain + is_floating = dtype.is_floating_point or dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs) + + if op.reference_numerics_filter is not None and dtype is not torch.bool: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + return a + + +# Restricts the values in the tensor to the domain of the +# given elementwise unary operator +def _filter_unary_elementwise_tensor(a, *, op): + # short-circuits for boolean tensors + if a.dtype is torch.bool: + return a + + low, high = op.domain + is_floating = a.dtype.is_floating_point or a.dtype.is_complex + low = low if low is None or not is_floating else low + op._domain_eps + high = high if high is None or not is_floating else high - op._domain_eps + + if a.dtype is torch.uint8 and low is not None: + low = max(low, 0) + + if not a.dtype.is_floating_point and not a.dtype.is_complex: + low = math.ceil(low) if low is not None else None + high = math.floor(high) if high is not None else None + + if op.reference_numerics_filter is not None: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + if low is not None or high is not None: + if a.dtype.is_complex: + a.real.clamp_(low, high) + a.imag.clamp_(low, high) + else: + a.clamp_(min=low, max=high) + + return a + + +def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs): + # Special-cases bool + if dtype is torch.bool: + tensors = ( + torch.empty(0, device=device, dtype=torch.bool), + torch.tensor(True, device=device), + torch.tensor(False, device=device), + torch.tensor((True, False), device=device), + make_tensor((812,), device=device, dtype=dtype), + make_tensor((1029, 917), device=device, dtype=dtype), + ) + for a in tensors: + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + shapes = ( + (1029, 917), + (812,), + # Empty sizes + (0,), + (0, 3, 3), + (1, 0, 5), + (6, 0, 0, 0), + (3, 0, 1, 0), + ) + + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + for shape in shapes: + a = make_arg(shape) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_small_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + yield SampleInput( + sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0] + ) + + +def generate_elementwise_unary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + # Generic noncontiguity + t = make_arg((1026,), noncontiguous=True) + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Transposed + t = make_arg((1024, 1024)).T + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + t = make_arg(shape) + t_non_contig = t.expand(3, -1, -1) + yield SampleInput( + t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0] + ) + + +def generate_elementwise_unary_arbitrarily_strided_tensors( + op, *, device, dtype, requires_grad=False +): + # shape, strides, offset + strided_cases = ( + ((5, 6, 2), (1, 1, 7), 2), + ((5, 5, 4), (1, 1, 7), 2), + ((5, 5, 2), (4, 5, 7), 3), + ((5, 5, 2), (5, 5, 7), 3), + ((5, 5, 2), (5, 5, 5), 3), + ((9, 5, 2), (0, 1, 7), 3), + ) + + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + for shape, strides, offset in strided_cases: + a = make_arg( + 500, + ).as_strided(shape, strides, offset) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +# Reuses the elementwise binary generators for consistency +# TODO: in the future generalize the reference generators to handle n-ary elementwise operations +def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + yield from generate_elementwise_unary_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype is not torch.bool: + yield from generate_elementwise_unary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + if dtype not in (torch.bool, torch.uint8, torch.int8) and ( + op.handles_large_floats + or (not dtype.is_floating_point and not dtype.is_complex) + ): + yield from generate_elementwise_unary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype.is_floating_point or ( + op.handles_complex_extremal_values and dtype.is_complex + ): + yield from generate_elementwise_unary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + gen = partial( + _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_unary_noncontiguous_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + yield from generate_elementwise_unary_arbitrarily_strided_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +# Metadata class for unary "universal functions (ufuncs)" that accept a single +# tensor and have common properties like: +class UnaryUfuncInfo(OpInfo): + """Operator information for 'universal unary functions (unary ufuncs).' + These are functions of a single tensor with common properties like: + - they are elementwise functions + - the input shape is the output shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, # the string name of the function + *, + dtypes=floating_types(), + domain=(None, None), # the [low, high) domain of the function + handles_complex_extremal_values=True, # whether the op correctly handles extremal values (like nan/inf) + handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) + supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle + sample_inputs_func=sample_inputs_elementwise_unary, + reference_inputs_func=reference_inputs_elementwise_unary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + reference_numerics_filter=None, # Filters values in the range of the domain specified above but that should not be tested + **kwargs, + ): + self._original_unary_ufunc_args = locals().copy() + + super().__init__( + name, + dtypes=dtypes, + sample_inputs_func=sample_inputs_func, + reference_inputs_func=reference_inputs_func, + **kwargs, + ) + self.domain = domain + self.handles_complex_extremal_values = handles_complex_extremal_values + self.handles_large_floats = handles_large_floats + self.supports_complex_to_float = supports_complex_to_float + self.reference_numerics_filter = reference_numerics_filter + + # test_unary_ufuncs.py generates its own inputs to test the consistency + # of the operator on sliced tensors, non-contig tensors, etc. + # `sample_kwargs` is a utility function to provide kwargs + # along with those inputs if required (eg. clamp). + # It should return two dictionaries, first holding kwarg for + # torch operator and second one for reference NumPy operator. + self.sample_kwargs = sample_kwargs + + # Epsilon to ensure grad and gradgrad checks don't test values + # outside a function's domain. + self._domain_eps = 1e-5 + + +def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs): + is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half + if not is_fp16_or_chalf: + nd_tensor = partial( + make_tensor, + (S, S + 1, S + 2), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad + ) + else: + # cuFFT supports powers of 2 for half and complex half precision + # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args + # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two + low = None + high = None + if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]: + shapes = ((2, 9, 9), (33,)) + elif self.name in [ + "fft.hfft2", + "fft.irfft2", + "_refs.fft.hfft2", + "_refs.fft.irfft2", + ]: + shapes = ((2, 8, 9), (33,)) + elif self.name in [ + "fft.hfftn", + "fft.irfftn", + "_refs.fft.hfftn", + "_refs.fft.irfftn", + ]: + shapes = ((2, 2, 33), (33,)) + # Adjusting the limits because the test would be flaky due to over-saturation of float16 + # See: https://github.com/pytorch/pytorch/pull/81416 + low = -1.0 + high = 1.0 + else: + shapes = ((2, 8, 16), (32,)) + nd_tensor = partial( + make_tensor, + shapes[0], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + oned_tensor = partial( + make_tensor, + shapes[1], + device=device, + low=low, + high=high, + dtype=dtype, + requires_grad=requires_grad, + ) + + if self.ndimensional == SpectralFuncType.ND: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(8,)) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)]) + elif self.ndimensional == SpectralFuncType.TwoD: + yield SampleInput( + nd_tensor(), + s=(3, 10) if not is_fp16_or_chalf else (4, 8), + dim=(1, 2), + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8)) + yield SampleInput(nd_tensor(), dim=0) + yield SampleInput(nd_tensor(), dim=(0, -1)) + yield SampleInput(nd_tensor(), dim=(-3, -2, -1)) + else: + yield SampleInput( + nd_tensor(), + n=10 if not is_fp16_or_chalf else 8, + dim=1, + norm="ortho", + ) + yield SampleInput(nd_tensor(), norm="ortho") + yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8) + yield SampleInput(oned_tensor()) + yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3]) + + +SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND")) + + +# Metadata class for Fast Fourier Transforms in torch.fft. +class SpectralFuncInfo(OpInfo): + """Operator information for torch.fft transforms.""" + + def __init__( + self, + name, # the string name of the function + *, + ref=None, # Reference implementation (probably in np.fft namespace) + dtypes=floating_and_complex_types(), + ndimensional: SpectralFuncType, + sample_inputs_func=sample_inputs_spectral_ops, + decorators=None, + **kwargs, + ): + self._original_spectral_func_args = dict(locals()).copy() + self._original_spectral_func_args.update(kwargs) + + decorators = list(decorators) if decorators is not None else [] + decorators += [ + skipCPUIfNoFFT, + DecorateInfo( + toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}), + "TestCommon", + "test_complex_half_reference_testing", + ), + ] + + super().__init__( + name=name, + dtypes=dtypes, + decorators=decorators, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + self.ndimensional = ndimensional + + +class ShapeFuncInfo(OpInfo): + """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" + + def __init__( + self, + name, # the string name of the function + *, + ref, # a reference function + dtypes=floating_types(), + dtypesIfCUDA=None, + dtypesIfROCM=None, + dtypesIfXPU=None, + sample_inputs_func=None, + **kwargs, + ): + super().__init__( + name, + dtypes=dtypes, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + dtypesIfXPU=dtypesIfXPU, + sample_inputs_func=sample_inputs_func, + **kwargs, + ) + self.ref = ref + + +def sample_inputs_foreach( + self, + device, + dtype, + N, + *, + noncontiguous=False, + same_size=False, + low=None, + high=None, + # zero_size means EVERY input is empty + zero_size: bool, + requires_grad: bool, + # mutually exclusive from same_size and zero_size, which are all or nothing + intersperse_empty_tensors: bool = False, +): + if zero_size: + return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)] + if same_size: + return [ + make_tensor( + (N, N), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for _ in range(N) + ] + else: + # interweave some empty tensors + have the last 2 tensors be empty (see #100701) + return [ + torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad) + if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors + else make_tensor( + (N - i, N - i), + dtype=dtype, + device=device, + noncontiguous=noncontiguous, + low=low, + high=high, + requires_grad=requires_grad, + ) + for i in range(N) + ] + + +def get_foreach_method_names(name): + # get torch inplace reference function + op_name = "_foreach_" + name + inplace_op_name = op_name + "_" + + op = getattr(torch, op_name, None) + inplace_op = getattr(torch, inplace_op_name, None) + + ref = getattr(torch, name, None) + ref_inplace = getattr(torch.Tensor, name + "_", None) + return op, inplace_op, ref, ref_inplace + + +@dataclass +class ForeachFuncInfo(OpInfo): + """Early version of a specialized OpInfo for foreach functions + + The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM` + are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments. + + ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``) + as the last keyword argument such as `_foreach_add`. + ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument. + Currently only `_foreach_pow` supports this. + ``backward_requires_result=True``, which could sound self-explanatory, means that the function uses + the forward result for its backward computation. + """ + + supports_alpha_param: bool = False + supports_scalar_self_arg: bool = False + backward_requires_result: bool = False + + def __post_init__(self): + ( + foreach_method, + foreach_method_inplace, + torch_ref_method, + torch_ref_inplace, + ) = get_foreach_method_names(self.name) + if not self.supports_out: + # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call + # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero` + # is not defined at the moment. Thus to skip the qualification, set a similar torch + # function. + assert foreach_method is None + assert torch_ref_method is None + foreach_method = foreach_method_inplace + torch_ref_method = torch_ref_inplace + + # We disable all complex128 tests internally for foreach due to reported flakiness + # tracked in #139648 + supported_dtypes = get_all_dtypes(include_qint=False) + if IS_FBCODE: + supported_dtypes = [ + x for x in supported_dtypes if x is not torch.complex128 + ] + self.dtypes = _dispatch_dtypes(supported_dtypes) + + self.op = foreach_method + self.method_variant = foreach_method + self.ref = torch_ref_method + self.inplace_variant = foreach_method_inplace + self.ref_inplace = torch_ref_inplace + self.has_no_in_place = self.inplace_variant is None + + name = self.name + self.name = f"_foreach_{name}" + if name == "norm": + self.ref = torch.linalg.vector_norm + elif name == "minimum": + # because minimum ref does not support inplace or scalar + self.ref = torch.clamp_max + self.ref_inplace = torch.Tensor.clamp_max_ + elif name == "maximum": + # because maximum ref does not support inplace or scalar + self.ref = torch.clamp_min + self.ref_inplace = torch.Tensor.clamp_min_ + + # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly. + super().__post_init__() + + def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs): + if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"): + return [] + return self.sample_inputs_func.sample_zero_size_tensor_inputs( + self, device, dtype, requires_grad, **kwargs + ) + + +def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): + """Gradcheck wrapper for functions that take Hermitian matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the Hermitian property of the input. + """ + return op(input + input.mH, *args, **kwargs) + + +def gradcheck_wrapper_ctc_loss(op, input, *args, **kwargs): + """Gradcheck wrapper for ctc loss to project onto log-simplex space.""" + # See https://github.com/pytorch/pytorch/issues/52241 + return op(input.log_softmax(dim=2), *args, **kwargs) + + +def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs): + """Gradcheck wrapper for functions that take lower or upper triangular matrices as input. + + They require a modified function because the finite-difference algorithm + for calculating derivatives does not preserve the triangular property of the input. + `idx` is used to specific which `args[idx]` is to be triangularized. + """ + triangular_arg = args[idx].triu() if upper else args[idx].tril() + return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs) + + +def gradcheck_wrapper_triangular_input_real_positive_diagonal( + op, *args, upper=False, idx=0, **kwargs +): + """Gradcheck wrapper for functions that take lower/upper triangular matrices + with real and positive diagonals, for example, cholesky-like operations. + """ + arg = args[idx] + arg_diag = arg.diagonal(0, -2, -1) + arg_diag_embed = torch.diag_embed(arg_diag) + id_diag_tensor = torch.ones_like(arg_diag) + id_tensor = torch.diag_embed(id_diag_tensor) + # new_arg = arg - diag(arg) + I + new_arg = arg - arg_diag_embed + id_tensor + return gradcheck_wrapper_triangular_input( + op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs + ) + + +def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked operations. + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + mask = kwargs.get("mask") + if mask is not None: + output_mask = torch.masked._output_mask(op, input, *args, **kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs): + """Gradcheck wrapper for masked pointwise operations. Assumes that the result + will be masked iff both tensors are masked at a specific index + + When mask is specified, replaces masked-out elements with zeros. + + Use for operations that produce non-finite masked-out elements, + for instance, for minimum and maximum reductions. + """ + output = op(input, *args, **kwargs) + input_mask = kwargs.get("input_mask") + other_mask = kwargs.get("other_mask") + if input_mask is not None and other_mask is not None: + combined_mask = torch.logical_and(input_mask, other_mask) + new_kwargs = dict(mask=combined_mask, **kwargs) + output_mask = torch.masked._input_mask(input, *args, **new_kwargs) + output = torch.where(output_mask, output, output.new_zeros([])) + return output + + +def clone_sample(sample, **kwargs): + """ + Given a SampleInput, this function analyzes its input, args and kwargs, + and produces a copy with each non-Tensor entry being copied by reference, + and with each Tensor entry cloned with `t.clone().requires_grad_(t.requires_grad)` + """ + + def clone_tensor(t): + if isinstance(t, torch.Tensor): + return t.detach().clone().requires_grad_(t.requires_grad) + else: + return t + + sample_kwargs = kwargs if kwargs else sample.kwargs + + return SampleInput( + clone_tensor(sample.input), + args=tuple(map(clone_tensor, sample.args)), + kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()}, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py new file mode 100644 index 0000000000000000000000000000000000000000..435a9d113164b3652af4d246655f579d1b72d4dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py @@ -0,0 +1,207 @@ +# mypy: ignore-errors + +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + OpInfo, + ReductionOpInfo, + UnaryUfuncInfo, +) + + +# NOTE [Python References] +# Python References emulate existing PyTorch operations, but can ultimately +# be expressed in terms of "primitive" operations from torch._prims. +# +# These references are experimental. +# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 +# for additional context. +# +# Python Reference OpInfos should be added to the python_ref_db list below. +# Tests can opt-into running on these references by including +# that list in the Sequence they pass to the @ops decorator. +# +# When a Python Reference OpInfo is constructed a pointer to an +# existing OpInfo must be provided using the torch_opinfo_name kwarg. +# The existing OpInfo with that name and no variant will be found +# to inherit from. +# +# Instead of just inheriting the existing OpInfo's metadata, the +# Python Reference OpInfos inherit the existing OpInfo's +# construction arguments. These arguments can be overridden +# by adding kwargs to the constructor. + + +def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None): + """ + Finds the OpInfo with the given name that has no variant name. + """ + # NOTE: searching the global op_db doesn't work when OpInfos are split into + # different modules, as otherwise the op_db will not be fully constructed + # yet. So, instead the local op_db must be passed in explicitly. + if op_db is None: + from torch.testing._internal.common_methods_invocations import op_db + + for opinfo in op_db: + if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name: + return opinfo + + +def _inherit_constructor_args(name, op, inherited, overrides): + # inherits metadata + common_kwargs = { + "name": name, + "op": op, + "aliases": None, # TODO add a check for alias coverage + "method_variant": None, + "inplace_variant": None, # TODO: add a check for inplace coverage + "supports_scripting": False, + } + + # Acquires inherited kwargs + kwargs = inherited.copy() + + # Fixes metadata + if "kwargs" in kwargs: + kwargs.update(kwargs["kwargs"]) + del kwargs["kwargs"] + if "self" in kwargs: + del kwargs["self"] + if "__class__" in kwargs: + del kwargs["__class__"] + if "skips" in kwargs: + del kwargs["skips"] + if "decorators" in kwargs: + del kwargs["decorators"] + + # Overrides metadata + kwargs.update(common_kwargs) + kwargs.update(overrides) + + # At the moment no prims support autograd, so we must not run autograd + # tests e.g. when testing dtype support. Once we start writing autograd + # formulas for prims this can be removed. + kwargs["supports_autograd"] = False + kwargs["supports_gradgrad"] = False + kwargs["supports_fwgrad_bwgrad"] = False + kwargs["supports_inplace_autograd"] = False + kwargs["supports_forward_ad"] = False + + return kwargs + + +class PythonRefInfo(OpInfo): + """ + An OpInfo for a Python reference of an OpInfo base class operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, OpInfo) + + inherited = self.torch_opinfo._original_opinfo_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + super().__init__(**ukwargs) + + +class ReductionPythonRefInfo(ReductionOpInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, ReductionOpInfo) + + inherited = self.torch_opinfo._original_reduction_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + # See https://github.com/pytorch/pytorch/issues/77216 + self.validate_view_consistency = False + + super().__init__(**ukwargs) + + +class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise unary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + validate_view_consistency=True, + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + self.validate_view_consistency = validate_view_consistency + assert isinstance(self.torch_opinfo, UnaryUfuncInfo) + + inherited = self.torch_opinfo._original_unary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) + + +class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo): + """ + An OpInfo for a Python reference of an elementwise binary operation. + """ + + def __init__( + self, + name, # the stringname of the callable Python reference + *, + op=None, # the function variant of the operation, populated as torch. if None + op_db=None, # The database of opinfos to search for the parent opinfo + torch_opinfo_name, # the string name of the corresponding torch opinfo + torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo + **kwargs, + ): # additional kwargs override kwargs inherited from the torch opinfo + self.torch_opinfo_name = torch_opinfo_name + self.torch_opinfo_variant_name = torch_opinfo_variant_name + self.torch_opinfo = _find_referenced_opinfo( + torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db + ) + assert isinstance(self.torch_opinfo, BinaryUfuncInfo) + + inherited = self.torch_opinfo._original_binary_ufunc_args + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) + + super().__init__(**ukwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e2127e956b46c711961bf90d822a461b99aedd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py @@ -0,0 +1,276 @@ +# mypy: ignore-errors + +import collections +import warnings +from collections.abc import Sequence +from functools import partial, wraps + +import numpy as np +import numpy.typing as npt + +import torch +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, + all_types, + all_types_and, + all_types_and_complex, + all_types_and_complex_and, + all_types_and_half, + complex_types, + floating_and_complex_types, + floating_and_complex_types_and, + floating_types, + floating_types_and, + floating_types_and_half, + integral_types, + integral_types_and, +) +from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict + + +COMPLETE_DTYPES_DISPATCH = ( + all_types, + all_types_and_complex, + all_types_and_half, + floating_types, + floating_and_complex_types, + floating_types_and_half, + integral_types, + complex_types, +) + +EXTENSIBLE_DTYPE_DISPATCH = ( + all_types_and_complex_and, + floating_types_and, + floating_and_complex_types_and, + integral_types_and, + all_types_and, +) + +# Better way to acquire devices? +DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else []) + + +class _dynamic_dispatch_dtypes(_dispatch_dtypes): + # Class to tag the dynamically generated types. + pass + + +def get_supported_dtypes(op, sample_inputs_fn, device_type): + # Returns the supported dtypes for the given operator and device_type pair. + assert device_type in ["cpu", "cuda"] + if not TEST_CUDA and device_type == "cuda": + warnings.warn( + "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!", + stacklevel=2, + ) + return _dynamic_dispatch_dtypes(()) + + supported_dtypes = set() + for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): + try: + samples = sample_inputs_fn(op, device_type, dtype, False) + except RuntimeError: + # If `sample_inputs_fn` doesn't support sampling for a given + # `dtype`, we assume that the `dtype` is not supported. + # We raise a warning, so that user knows that this was the case + # and can investigate if there was an issue with the `sample_inputs_fn`. + warnings.warn( + f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}", + stacklevel=2, + ) + continue + + # We assume the dtype is supported + # only if all samples pass for the given dtype. + supported = True + for sample in samples: + try: + op(sample.input, *sample.args, **sample.kwargs) + except RuntimeError: + # dtype is not supported + supported = False + break + + if supported: + supported_dtypes.add(dtype) + + return _dynamic_dispatch_dtypes(supported_dtypes) + + +def dtypes_dispatch_hint(dtypes): + # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH) + # and its string representation for the passed `dtypes`. + return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str") + + # CUDA is not available, dtypes will be empty. + if len(dtypes) == 0: + return return_type((), "()") + + set_dtypes = set(dtypes) + for dispatch in COMPLETE_DTYPES_DISPATCH: + # Short circuit if we get an exact match. + if set(dispatch()) == set_dtypes: + return return_type(dispatch, dispatch.__name__ + "()") + + chosen_dispatch = None + chosen_dispatch_score = 0.0 + for dispatch in EXTENSIBLE_DTYPE_DISPATCH: + dispatch_dtypes = set(dispatch()) + if not dispatch_dtypes.issubset(set_dtypes): + continue + + score = len(dispatch_dtypes) + if score > chosen_dispatch_score: + chosen_dispatch_score = score + chosen_dispatch = dispatch + + # If user passed dtypes which are lower than the lowest + # dispatch type available (not likely but possible in code path). + if chosen_dispatch is None: + return return_type((), str(dtypes)) + + return return_type( + partial(dispatch, *tuple(set(dtypes) - set(dispatch()))), + dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))), + ) + + +def is_dynamic_dtype_set(op): + # Detect if the OpInfo entry acquired dtypes dynamically + # using `get_supported_dtypes`. + return op.dynamic_dtypes + + +def str_format_dynamic_dtype(op): + fmt_str = f""" + OpInfo({op.name}, + dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str}, + dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str}, + ) + """ + + return fmt_str + + +def np_unary_ufunc_integer_promotion_wrapper(fn): + # Wrapper that passes PyTorch's default scalar + # type as an argument to the wrapped NumPy + # unary ufunc when given an integer input. + # This mimics PyTorch's integer->floating point + # type promotion. + # + # This is necessary when NumPy promotes + # integer types to double, since PyTorch promotes + # integer types to the default scalar type. + + # Helper to determine if promotion is needed + def is_integral(dtype): + return dtype in [ + np.bool_, + bool, + np.uint8, + np.int8, + np.int16, + np.int32, + np.int64, + ] + + @wraps(fn) + def wrapped_fn(x): + # As the default dtype can change, acquire it when function is called. + # NOTE: Promotion in PyTorch is from integer types to the default dtype + np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] + + if is_integral(x.dtype): + return fn(x.astype(np_dtype)) + return fn(x) + + return wrapped_fn + + +def reference_reduction_numpy(f, supports_keepdims=True): + """Wraps a NumPy reduction operator. + + The wrapper function will forward dim, keepdim, mask, and identity + kwargs to the wrapped function as the NumPy equivalent axis, + keepdims, where, and initiak kwargs, respectively. + + Args: + f: NumPy reduction operator to wrap + supports_keepdims (bool, optional): Whether the NumPy operator accepts + keepdims parameter. If it does not, the wrapper will manually unsqueeze + the reduced dimensions if it was called with keepdim=True. Defaults to True. + + Returns: + Wrapped function + + """ + + @wraps(f) + def wrapper(x: npt.NDArray, *args, **kwargs): + # Copy keys into a set + keys = set(kwargs.keys()) + + dim = kwargs.pop("dim", None) + keepdim = kwargs.pop("keepdim", False) + + if "dim" in keys: + dim = tuple(dim) if isinstance(dim, Sequence) else dim + + # NumPy reductions don't accept dim=0 for scalar inputs + # so we convert it to None if and only if dim is equivalent + if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}: + kwargs["axis"] = None + else: + kwargs["axis"] = dim + + if "keepdim" in keys and supports_keepdims: + kwargs["keepdims"] = keepdim + + if "mask" in keys: + mask = kwargs.pop("mask") + if mask is not None: + assert mask.layout == torch.strided + kwargs["where"] = mask.cpu().numpy() + + if "identity" in keys: + identity = kwargs.pop("identity") + if identity is not None: + if identity.dtype is torch.bfloat16: + identity = identity.cpu().to(torch.float32) + else: + identity = identity.cpu() + kwargs["initial"] = identity.numpy() + + result = f(x, *args, **kwargs) + + # Unsqueeze reduced dimensions if NumPy does not support keepdims + if keepdim and not supports_keepdims and x.ndim > 0: + dim = list(range(x.ndim)) if dim is None else dim + result = np.expand_dims(result, dim) + + return result + + return wrapper + + +def prod_numpy(a, *args, **kwargs): + """ + The function will call np.prod with type as np.int64 if the input type + is int or uint64 if is uint. This is necessary because windows np.prod uses by default + int32 while on linux it uses int64. + This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320 + + Returns: + np.prod of input + """ + if "dtype" not in kwargs: + if np.issubdtype(a.dtype, np.signedinteger): + a = a.astype(np.int64) + elif np.issubdtype(a.dtype, np.unsignedinteger): + a = a.astype(np.uint64) + + fn = reference_reduction_numpy(np.prod) + return fn(a, *args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9125ba0ebe7e0623a12ad1a1cd7eeb7d2749a3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py @@ -0,0 +1,7 @@ +# mypy: ignore-errors + +from .make_fx import make_fx_check +from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper +from .fake_tensor import fake_check +from .autograd_registration import autograd_registration_check +from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4d05a95a33e262e19efbb4cbb0d3a01d3dbf3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py @@ -0,0 +1,175 @@ +# mypy: ignore-errors + +import torch +import torch.utils._pytree as pytree +from torch.testing._utils import wrapper_set_seed +from functorch.compile import compiled_function, min_cut_rematerialization_partition, default_partition, nop +from .make_fx import randomize +import re + + +class assert_raises_regex: + def __init__(self, exception_cls, regex): + self.exception_cls = exception_cls + self.regex = regex + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, traceback): + if exc_type == self.exception_cls: + msg = str(exc_val) + if not re.search(self.regex, msg): + raise AssertionError( + f"Expected exception to match regex. regex: {self.regex}, exception: {msg}") + return True # Squashes the exception + if exc_type is not None: + raise AssertionError( + f"Expected {self.exception_cls} to be raised, instead got exception {exc_type}") + raise AssertionError("Expected exception to be raised but none was") + + +def aot_autograd_check( + func, + args, + kwargs, + dynamic, + assert_raises_regex_fn=assert_raises_regex, + assert_equals_fn=torch.testing.assert_close, + check_gradients=True, + try_check_data_specialization=False, + skip_correctness_check=False, + disable_functionalization=False): + """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd. + + Compares outputs and (if check_gradients=True) gradients produced by + AOTAutograd against eager-mode PyTorch. + + We assume that func(*args, **kwargs) succeeds in eager-mode PyTorch. + + """ + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + + # We construct a new function that only accepts Tensors as inputs + def func_no_tensors(args): + reconstructed_flat_args = [] + args = iter(args) + for v in flat_args: + if isinstance(v, torch.Tensor): + reconstructed_flat_args.append(next(args)) + else: + reconstructed_flat_args.append(v) + + c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec) + return func(*c_args, **c_kwargs) + + # cannot use the min cut partitioner without functionalization + if disable_functionalization: + compiled_f = compiled_function( + func_no_tensors, + nop, + nop, + dynamic=dynamic, + partition_fn=default_partition, + keep_inference_input_mutations=True, + disable_functionalization=True + ) + else: + compiled_f = compiled_function( + func_no_tensors, + nop, + nop, + dynamic=dynamic, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + disable_functionalization=False + ) + + out = wrapper_set_seed(func_no_tensors, args) + if check_gradients == "auto": + any_tensor_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, args) + any_output_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, out) + check_gradients = any_tensor_requires_grad and any_output_requires_grad + if not check_gradients: + compiled_out = wrapper_set_seed(compiled_f, args) + if not skip_correctness_check: + assert_equals_fn(compiled_out, out, msg=outputs_msg) + return + _test_aot_autograd_forwards_backwards_helper( + func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check) + +outputs_msg = ( + "Outputs of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher tracing. This means the operator will have incorrect output " + "underneath torch.compile. This could be because the operator's " + "implementation not traceable." +) + + +def _test_aot_autograd_forwards_backwards_helper( + f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn, + try_check_data_specialization, skip_correctness_check=False): + # Verify grads are equal between compiled and non-compiled versions of f. + + def call_forwards_backwards(f, args): + flat_args = pytree.arg_tree_leaves(*args) + diff_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor) and + arg.requires_grad] + out = wrapper_set_seed(f, args) + flat_out = pytree.tree_leaves(out) + + sm = 0 + for i in flat_out: + if isinstance(i, torch.Tensor): + # We need to call .abs() because it is possible that the output of the + # operator is a complex Tensor and autograd will yell at autograd.grad + # on a complex Tensor unless we manually provide the grad_output flag. + sm += i.sum().abs() + assert isinstance(sm, torch.Tensor) + return out, torch.autograd.grad(sm, diff_args, allow_unused=True) + + def check(args, ignore_failure=False): + try: + orig_out, orig_grad = call_forwards_backwards(f, args) + except Exception: + if ignore_failure: + return + raise + + # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 + tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all(x is None for x in orig_grad) and any_non_leaves: + with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'): + call_forwards_backwards(compiled_f, args) + return + + msg = ( + "Gradients of the operator are different in eager-mode PyTorch vs " + "AOTDispatcher. This means the operator will have incorrect gradients " + "underneath torch.compile. This could be because the operator's " + "backward is incorrectly registered or not traceable." + ) + + compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args) + if not skip_correctness_check: + try: + assert_equals_fn(compiled_out, orig_out) + except Exception as e: + raise type(e)(outputs_msg) from e + try: + assert_equals_fn(compiled_grad, orig_grad) + except Exception as e: + raise type(e)(msg) from e + + check(args, ignore_failure=False) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if this test fails. + if try_check_data_specialization: + args = randomize(args) + check(args, ignore_failure=True) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5ae34059eaa3d7ae1197699638f52f86538b02 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py @@ -0,0 +1,134 @@ +# mypy: ignore-errors + +import contextlib + +import torch +import torch.utils._pytree as pytree + + +@contextlib.contextmanager +def set_autograd_fallback_mode(mode): + prev = torch._C._get_autograd_fallback_mode() + try: + torch._C._set_autograd_fallback_mode(mode) + yield + finally: + torch._C._set_autograd_fallback_mode(prev) + + +def autograd_registration_check(op, args, kwargs): + """Check if autograd was registered correctly (for the operator). + + Operators should have "autograd support" registered directly to an + autograd dispatch key. + An incorrect registration may lead to unexpected silent incorrectness. + Note that this check won't catch all problems but will catch + the most common ones. + + Example usage: + >>> x = torch.randn(3, requires_grad=True) + >>> autograd_registration_check(torch.ops.aten.sin.default, (x,), {}) + + Here are some best practices if you do find your autograd is + registered incorrectly: + - If the operator is composite (i.e. consists of other PyTorch ops) + and you wish the operator to decompose and get autograd support + that way, then please register the implementation to + DispatchKey::CompositeImplicitAutograd + - If you're adding an autograd formula for the operator, the correct + thing to do is to register an autograd.Function to + DispatchKey::Autograd (preferred) or one of the + DispatchKey::Autograd keys. It is NOT OK to register + an autograd.Function to a backend (e.g. CPU/CUDA) key. + - If your operator is non-differentiable, then you should register + an implementation to the Autograd key that uses + AutoDispatchBelowAutograd and re-invokes the operator. + + """ + assert isinstance(op, torch._ops.OpOverload) + # Implementation details + # ----------------------------------------------- + # If an operator doesn't have an autograd kernel at an autograd key, + # and the operator does not return inputs as-is, then all of + # the outputs should have requires_grad=False before we apply + # special behaviors of our default autograd fallback. + # (The default autograd fallback may set requires_grad=True on output + # tensors in certain modes so that when they are backpropped through, + # they raise an error). + # + # Our strategy for detecting if an operator doesn't have an autograd + # kernel at the autograd key is: + # - set the autograd fallback mode to "nothing" (so it does not change + # the required-gradness of outputs) + # - run the operator + # - Check if any outputs of the operator (that are not inputs) require + # grad. This would only happen if the user calls regular PyTorch + # operations in their backend key (this op should instead be + # CompositeImplicitAutograd or not an op) or if the user invokes + # an autograd.Function in the backend key. + # + # Note that it's already likely a bug if the operator directly returns + # an input as output (because custom ops don't have a good way of + # constructing true in-place or out variants), but we defer that + # responsibility to a different test (schema_check). + + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] + if not any(t.requires_grad for t in all_tensors): + raise RuntimeError( + "autograd_registration_check: no inputs have requires_grad=True so " + "we are unable to actually perform this test. Please pass inputs " + "that do require grad." + ) + + # Determine which AutogradBACKEND key to check + all_device_types = {arg.device.type for arg in all_tensors} + if not all_device_types.issubset(["cpu", "cuda", "xpu"]): + # Don't want to support other keys yet + raise NotImplementedError( + f"autograd_registration_check: NYI devices other than CPU/CUDA/XPU, got {all_device_types}" + ) + if "cuda" in all_device_types: + key = "AutogradCUDA" + elif "cpu" in all_device_types: + key = "AutogradCPU" + elif "xpu" in all_device_types: + key = "AutogradXPU" + + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key): + return + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Autograd"): + return + if torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), "CompositeImplicitAutograd" + ): + return + + # At this point, we know the operator doesn't have a kernel registered to an + # autograd key. Let's proceed with our test. + with set_autograd_fallback_mode("nothing"): + all_outs = op(*args, **kwargs) + + inp_ids = {id(arg) for arg in flat_args} + + def not_an_input_and_requires_grad(tensor): + if not tensor.requires_grad: + return False + if id(tensor) in inp_ids: + return False + return True + + if not pytree.tree_any_only(torch.Tensor, not_an_input_and_requires_grad, all_outs): + return + + raise AssertionError( + f"{op.name()}: at least one output of this operator has requires_grad=True " + f"but the operator does not have an autograd kernel defined at an autograd " + f"key (e.g. DispatchKey::Autograd). This could mean that you have " + f"incorrectly registered an autograd kernel to a non-Autograd DispatchKey, " + f"which may lead to silently incorrect results. If your operator consists " + f"of regular PyTorch operations, consider not using an operator at all " + f"or registering your operator as CompositeImplicitAutograd. If you have " + f"an autograd.Function registered to a backend (CPU/CUDA/XPU) key, the correct " + f"location for it is the Autograd key." + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5e60f50189b5dc3ab43fdd97120d5fa23559a84e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py @@ -0,0 +1,12 @@ +# mypy: ignore-errors + +import torch._subclasses + + +def is_builtin(op): + return op.namespace in ('aten', 'prims', 'prim') + + +def fake_check(op, args, kwargs): + with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin): + op(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..398425853f09adccce056b4115042f9379a1a9b3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py @@ -0,0 +1,852 @@ +# mypy: ignore-errors + +import datetime +import difflib +import functools +import inspect +import json +import os +import re +import tempfile +import threading +import unittest +from collections.abc import Callable, Sequence +from typing import Any, Optional, Union + +import torch +import torch._dynamo +import torch.utils._pytree as pytree +from torch._dynamo.utils import clone_input +from torch._library.custom_ops import CustomOpDef +from torch._subclasses.schema_check_mode import SchemaCheckMode +from torch._utils_internal import get_file_path_2 +from torch.overrides import TorchFunctionMode +from torch.testing._internal.optests import ( + aot_autograd_check, + autograd_registration_check, + fake_check, +) + + +def dontGenerateOpCheckTests(reason: str): + def inner(fun): + fun._torch_dont_generate_opcheck_tests = True + return fun + + return inner + + +def is_abstract(tensor: torch.Tensor) -> bool: + if tensor.is_meta: + return True + if torch._subclasses.fake_tensor.is_fake(tensor): + return True + return False + + +def safe_schema_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + with SchemaCheckMode(): + result = op(*args, **kwargs) + return result + + +def safe_autograd_registration_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + # Don't perform autograd_registration_check if none of the inputs require grad. + if not pytree.tree_any_only( + torch.Tensor, lambda x: x.requires_grad, (args, kwargs) + ): + return + return autograd_registration_check(op, args, kwargs) + + +def safe_fake_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> None: + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + if copy_inputs: + args, kwargs = deepcopy_tensors((args, kwargs)) + return fake_check(op, args, kwargs) + + +def safe_aot_autograd_check( + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic: bool, + *, + copy_inputs: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> Any: + # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy + # inputs. + if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)): + return None + + def func(*args, **kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs)) + return op(*args, **kwargs) + + # aot_autograd_check runs func(*args, **kwargs) multiple times + # and assumes `func` does not modify its inputs. + if rtol and atol: + assert_equals_fn = functools.partial( + torch.testing.assert_close, rtol=rtol, atol=atol + ) + else: + assert_equals_fn = torch.testing.assert_close + return aot_autograd_check( + func, + args, + kwargs, + dynamic, + check_gradients="auto", + assert_equals_fn=assert_equals_fn, + ) + + +def deepcopy_tensors(inputs: Any) -> Any: + return pytree.tree_map_only(torch.Tensor, clone_input, inputs) + + +# Test util requirements +# - The test util must have signature (op: OpOverload, args, kwargs) +# - The test util must NOT mutate args, kwargs. +# - The test utils in this list must not be prefixes of each other. For example, +# having both "test_schema" and "test_schema_is_functional" is NOT OK. +# - The order of items in this dict matters (for opcheck), we'll run them +# in order. +ALL_TEST_UTILS = { + "test_schema": safe_schema_check, + "test_autograd_registration": safe_autograd_registration_check, + "test_faketensor": safe_fake_check, + "test_aot_dispatch_static": functools.partial( + safe_aot_autograd_check, + dynamic=False, + ), + "test_aot_dispatch_dynamic": functools.partial( + safe_aot_autograd_check, + dynamic=True, + ), +} + +GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit" + +DEFAULT_TEST_UTILS = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", +] + +DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [ + "test_aot_dispatch_static", +] + + +def generate_opcheck_tests( + testcase: Any, + namespaces: list[str], + failures_dict_path: Optional[str] = None, + additional_decorators: Optional[dict[str, Callable]] = None, + test_utils: list[str] = DEFAULT_TEST_UTILS, +) -> None: + """Given an existing TestCase, use the existing tests to generate + additional validation tests for custom operators. + + For {all existing tests in the TestCase} x {all test utils}, + we will generate one new test. The new test runs a TorchFunctionMode + that intercepts ``op(*args, **kwargs)`` calls and invokes + ``test_util(op, *args, **kwargs)``, where ``op`` is an operator. + + The test_util that we support are in ALL_TEST_UTILS. They are: + - test_schema: This runs SchemaCheckMode. + - test_autograd_registration: This runs autograd_registration_check. + - test_faketensor: This runs CrossRefFakeMode. + - test_aot_dispatch_static: This runs aot_autograd_check, which: + checks that the outputs (and gradients, if they are computable) + are the same under eager-mode PyTorch and using AOTAutograd. + - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but + runs AOTAutograd using dynamic shapes instead of static shapes. + + The generated test will have name ``{test_util}__{original_name}``. + For example, if there is a method named ``test_cumsum``, then + we will generate a ``test_schema__test_cumsum``, + ``test_faketensor__test_cumsum``, etc. + + For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit + + Args: + testcase: The testcase we will modify and generate additional tests for. + namespaces: We will only intercept calls to custom operators with these + namespaces. + failures_dict_path: See ``validate_failures_dict_structure`` for more details + test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"] + """ + if additional_decorators is None: + additional_decorators = {} + test_methods = [ + m + for m in dir(testcase) + if m.startswith("test_") and callable(getattr(testcase, m)) + ] + if failures_dict_path is None: + # The default failures_dict_path is failures_dict.json in + # the same directory as the test file. + prev_frame = inspect.currentframe().f_back + filename = inspect.getframeinfo(prev_frame)[0] + failures_dict_path = get_file_path_2( + os.path.dirname(filename), "failures_dict.json" + ) + failures_dict = FailuresDict.load( + failures_dict_path, create_file=should_update_failures_dict() + ) + validate_failures_dict_structure(failures_dict, test_utils, testcase) + validate_failures_dict_formatting(failures_dict_path) + + def construct_method(attr, prefix, tester): + method = getattr(testcase, attr) + if getattr(method, "_torch_dont_generate_opcheck_tests", False): + return + new_method_name = prefix + "__" + attr + + @functools.wraps(method) + def new_method(*args, **kwargs): + with OpCheckMode( + namespaces, + prefix, + tester, + failures_dict, + f"{testcase.__name__}.{new_method_name}", + failures_dict_path, + ): + result = method(*args, **kwargs) + return result + + if pytestmark := new_method.__dict__.get("pytestmark"): + import pytest + + # check if we need to simplify the parametrize marks + # NB: you need to add this mark to your pytest.ini + opcheck_only_one = False + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one": + opcheck_only_one = True + + if opcheck_only_one: + new_pytestmark = [] + for mark in pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "parametrize": + argnames, argvalues = mark.args + assert not mark.kwargs, "NYI" + # Special case for device, we want to run on all + # devices + if argnames != "device": + new_pytestmark.append( + pytest.mark.parametrize( + argnames, (next(iter(argvalues)),) + ) + ) + continue + new_pytestmark.append(mark) + new_method.__dict__["pytestmark"] = new_pytestmark + + if new_method_name in additional_decorators: + for dec in additional_decorators[new_method_name]: + new_method = dec(new_method) + + if hasattr(testcase, new_method_name): + raise RuntimeError( + f"Tried to autogenerate {new_method_name} but {testcase} already " + f"has method named {new_method_name}. Please rename the original " + f"method on the TestCase." + ) + setattr(testcase, new_method_name, new_method) + + test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils} + for attr in test_methods: + for prefix, tester in test_utils.items(): + construct_method(attr, prefix, tester) + + generate_tag_tests(testcase, failures_dict, additional_decorators) + + +def generate_tag_tests(testcase, failures_dict, additional_decorators): + def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests): + def inner(self): + try: + op = torch._library.utils.lookup_op(qualname) + except AttributeError as e: + # Operator not importable in this test file + raise unittest.SkipTest(f"Can't import operator {qualname}") from e + op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags + if not op_marked_as_compliant: + return + if not definitely_not_pt2_compliant: + return + raise AssertionError( + f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag " + f"but it failed some of the generated opcheck tests " + f"({xfailed_tests}). This may lead to silent correctness issues, " + f"please fix this." + ) + + return inner + + for qualname, test_dict in failures_dict.data.items(): + xfailed_tests = [ + test + for test, status_dict in test_dict.items() + # We're about to delete the following test after Ed's PR + # to specialize on C++ .size() calls + if "test_aot_dispatch_static" not in test + and status_dict["status"] == "xfail" + ] + definitely_not_pt2_compliant = len(xfailed_tests) > 0 + generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests) + + # Could result in collisions, but unlikely. We'll raise if we see one below. + mangled_qualname = qualname.replace("::", "_").replace(".", "_") + test_name = "test_pt2_compliant_tag_" + mangled_qualname + + # You can skip this test via the additional_decorators argument + # in generate_opcheck_tests + if test_name in additional_decorators: + for decorator in additional_decorators[test_name]: + generated = decorator(generated) + + if hasattr(testcase, test_name): + raise RuntimeError( + f"Tried to generate a test named {test_name}, but it exists " + f"already. This could be because of a name collision (where " + f"we generated two tests with the same name), or where we " + f"generated a test with the same name as an existing test." + ) + setattr(testcase, test_name, generated) + + +TEST_OPTIONS = ("xfail", "skip", "xsuccess") + + +def validate_failures_dict_formatting(failures_dict_path: str) -> None: + with open(failures_dict_path) as fp: + actual = fp.read() + failures_dict = FailuresDict.load(failures_dict_path) + expected = failures_dict._save(to_str=True) + if actual == expected: + return + if should_update_failures_dict(): + failures_dict = FailuresDict.load(failures_dict_path) + failures_dict.save() + return + expected = expected.splitlines(1) + actual = actual.splitlines(1) + diff = difflib.unified_diff(actual, expected) + diff = "".join(diff) + raise RuntimeError( + f"\n{diff}\n\nExpected the failures dict to be formatted " + f"a certain way. Please see the above diff; you can correct " + f"this either manually or by re-running the test with " + f"PYTORCH_OPCHECK_ACCEPT=1" + ) + + +def validate_failures_dict_structure( + failure_dict: "FailuresDict", test_utils: list[str], testcase: Any +) -> None: + """Validates the failures dict. + + The failure dict looks something like the following. + It maps operator name (qualname) to a list of autogenerated tests. + Each autogenerated test may have a check for the operator (if the operator is + called by the test); the dictionary specifies if we should skip the check, + or if we expect some check to fail. + + { + "fbgemm::split_lengths": { + "test_schema__test_split_lengths": { + "comment": "you can put whatever you want into the comment section", + "status": "xfail", + } + "test_schema__test_split_lengths_empty": { + "comment": "", + "status": "skip", + }, + }, + "fbgemm::gather_lengths": { + "test_schema__test_gather_lengths": { + "comment": "", + "status": "skip", + }, + }, + } + + """ + failure_dict = failure_dict.data + for test_to_option in failure_dict.values(): + for test_name, test_dict in test_to_option.items(): + if set(test_dict.keys()) != set({"comment", "status"}): + raise RuntimeError( + "in failures_dict, expected sub-dict to have keys 'comment' and 'status'" + ) + test_option = test_dict["status"] + if test_option not in TEST_OPTIONS: + raise RuntimeError( + f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}" + ) + test_class, actual_test_name = test_name.split(".") + if not any(actual_test_name.startswith(test) for test in test_utils): + raise RuntimeError( + f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}" + ) + for test in test_utils: + if not actual_test_name.startswith(test): + continue + base_test_name = actual_test_name[len(test) + 2 :] + # remove potential pytest parametrization suffix + base_test_name = re.sub(r"\[.*\]", "", base_test_name) + if testcase.__name__ != test_class: + continue + if hasattr(testcase, base_test_name): + continue + raise RuntimeError( + f"In failures dict, got test name '{test_name}'. We parsed this as " + f"running test '{test}' on '{base_test_name}', but " + f"{base_test_name} does not exist on the TestCase '{testcase.__name__}]. " + f"Maybe you need to change the test name?" + ) + + +def should_update_failures_dict() -> bool: + key = "PYTORCH_OPCHECK_ACCEPT" + return key in os.environ and os.environ[key] == "1" + + +_is_inside_opcheck_mode = threading.local() +_is_inside_opcheck_mode.value = False + + +def is_inside_opcheck_mode(): + return _is_inside_opcheck_mode.value + + +class OpCheckMode(TorchFunctionMode): + """ + For a given test, OpCheckMode intercepts calls to operators and runs + test_util(op, args, kwargs) for each intercepted (op, args, kwargs). + """ + + def __init__( + self, + namespaces: list[str], + test_util_name: str, + test_util: Callable, + failures_dict: "FailuresDict", + test_name: str, + failures_dict_path: str, + ): + # We will intercept calls to ops with these namespaces + self.namespaces = namespaces + # The test utility function. Its signature should be (op, args, kwargs) -> None. + # Examples of test utilities are: schema_check, make_fx_check + self.test_util = test_util + self.test_util_name = test_util_name + # The name of the test that is running this OpCheckMode. + self.test_name = test_name + # Maps qualname -> test_name -> skip/xfail + # Tells us if we should skip a test or assert that there is a failure. + self.failures_dict = failures_dict + # Location of the failures dict. Makes it so that the error message is better. + self.failures_dict_path = failures_dict_path + + # OpCheckMode suppresses errors, collects them here, and then raises them on exit. + # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)] + self.seen_ops_to_errors = {} + + def maybe_raise_errors_on_exit(self) -> None: + # Check expected failures first + for qualname in self.seen_ops_to_errors: + option = self.failures_dict.get_status(qualname, self.test_name) + if len(self.seen_ops_to_errors[qualname]) == 0: + if should_update_failures_dict(): + self.failures_dict.set_status( + qualname, self.test_name, "xsuccess", comment="" + ) + else: + if option == "xfail": + raise OpCheckError( + f"generate_opcheck_tests: Unexpected success for operator " + f"{qualname} on test {self.test_name}. This may mean that " + f"you have fixed this test failure. Please rerun the test with " + f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner " + f"or manually remove the " + f"expected failure in the failure dict at " + f"{self.failures_dict_path}" + f"For more details, see " + f"{GDOC}" + ) + continue + failed_ops = [] + for qualname in self.seen_ops_to_errors: + option = self.failures_dict.get_status(qualname, self.test_name) + if option != "xsuccess": + continue + if len(self.seen_ops_to_errors[qualname]) == 0: + continue + failed_ops.append(qualname) + if not failed_ops: + return + + if should_update_failures_dict(): + for op in failed_ops: + self.failures_dict.set_status(op, self.test_name, "xfail") + return + + # Raise from the first error but also report about all of them to make + # recording xfails easier. + ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0] + repro_command = generate_repro( + self.test_util_name, op, args, kwargs, save_data=should_print_better_repro() + ) + raise OpCheckError( + f"Test generated by `generate_opcheck_tests`, {self.test_name}, " + f"failed on operators {failed_ops}. This usually means that the " + f"operators are not implemented correctly and may lead to silently " + f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, " + f"or please see " + f"{GDOC} " + f"for more recommendations. " + f"To reproduce this problem locally, try to run the following:\n{repro_command}" + ) from ex + + def __enter__(self, *args, **kwargs): + self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value + self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE", "") + _is_inside_opcheck_mode.value = True + os.environ["TORCHDYNAMO_DISABLE"] = "1" + return super().__enter__(*args, **kwargs) + + def __exit__(self, *args, **kwargs): + _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode + os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable + try: + self.maybe_raise_errors_on_exit() + if should_update_failures_dict(): + self.failures_dict.save() + finally: + result = super().__exit__(*args, **kwargs) + return result + + def run_test_util(self, op, args, kwargs): + try: + self.test_util(op, args, kwargs, copy_inputs=False) + except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: + # We might get here if the input is already a FakeTensor + # or if we're in a torch.compile block. Just ignore these + # since we can't handle them and reporting them as failures + # is too noisy. + pass + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # Only intercept calls to operators + if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): + return func(*args, **kwargs) + if ( + torch.jit.is_tracing() + or torch.jit.is_scripting() + or torch._dynamo.is_compiling() + ): + return func(*args, **kwargs) + # Pre-existing code may not use the .default overload. If we see an + # OpOverloadPacket and we cannot resolve the overload, then we just throw + # and ask the user to clarify. Otherwise, we attempt to resolve the overload. + if isinstance(func, torch._ops.OpOverloadPacket): + func = resolve_unique_overload_or_throw(func) + qualname = func.name() + ns = qualname.split("::")[0] + if ns not in self.namespaces: + return func(*args, **kwargs) + + args_c, kwargs_c = deepcopy_tensors((args, kwargs)) + result = func(*args, **kwargs) + + option = self.failures_dict.get_status(qualname, self.test_name) + if option == "xsuccess" or option == "xfail": + # Suppress all errors during execution. Raise them during __exit__. + try: + if qualname not in self.seen_ops_to_errors: + self.seen_ops_to_errors[qualname] = [] + self.run_test_util(func, args_c, kwargs_c) + except Exception as ex: + if should_print_better_repro(): + self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs)) + else: + self.seen_ops_to_errors[qualname].append((ex, func, None, None)) + elif option == "skip": + pass + return result + + +def should_print_better_repro() -> None: + """If set, the tests generated by `generate_opcheck_tests` will print a + repro command on failure. + + In order to print the repro command, we need to save some tensors to disk. + These will be saved under the following directory: + {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/. + + Although this is a temp folder, it will usually not automatically get cleaned + up, so you'll need to manually delete it. + """ + key = "PYTORCH_OPCHECK_PRINT_BETTER_REPRO" + if key not in os.environ: + return False + value = os.environ[key] + return value == "1" or value == 1 + + +def opcheck( + op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS, + raise_exception: bool = True, + rtol: Optional[float] = None, + atol: Optional[float] = None, +) -> dict[str, str]: + """See torch.library.opcheck for docstring""" + + if (rtol is None) ^ (atol is None): + raise ValueError( + "opcheck(op, ...): if you specify one of rtol/atol, you must specify both" + ) + + if kwargs is None: + kwargs = {} + if isinstance(op, CustomOpDef): + op = op._opoverload + if isinstance(op, torch._ops.OpOverloadPacket): + op = resolve_unique_overload_or_throw(op) + if not isinstance(op, torch._ops.OpOverload): + raise ValueError( + f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, " + f"e.g. torch.ops.aten.sin.default, got {type(op)}" + ) + if test_utils == "ALL": + test_utils = tuple(ALL_TEST_UTILS.keys()) + if isinstance(test_utils, str): + test_utils = (test_utils,) + if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset( + ALL_TEST_UTILS.keys() + ): + raise ValueError( + f"opcheck(op, ..., test_utils={test_utils}), expected test_utils " + f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not" + ) + + results_dict = {} + for test_util in test_utils: + tester = ALL_TEST_UTILS[test_util] + try: + tester(op, args, kwargs, rtol=rtol, atol=atol) + results_dict[test_util] = "SUCCESS" + except Exception as ex: + if raise_exception: + raise OpCheckError( + f"opcheck(op, ...): {test_util} failed with {ex} " + f"(scroll up for stack trace)" + ) from ex + results_dict[test_util] = ex + return results_dict + + +class OpCheckError(Exception): + pass + + +def generate_repro( + test: str, + op: torch._ops.OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + save_data: bool, + dry_run: bool = False, +) -> str: + if save_data: + now = datetime.datetime.now() + path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete") + unix_timestamp = datetime.datetime.timestamp(now) * 100000 + filepath = os.path.join(path, f"repro_{unix_timestamp}.pt") + if not dry_run: + os.makedirs(path, exist_ok=True) + torch.save((args, kwargs), filepath) + args_kwargs = f'args, kwargs = torch.load("{filepath}")' + else: + args_kwargs = ( + "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n" + "# we will fill them in same (args, kwargs) as in your test\n" + "args = () # args to the operator\n" + "kwargs = {} # kwargs to the operator" + ) + + ns, name = op._schema.name.split("::") + overload = op._overloadname + + repro_command = ( + f"# =========================================================\n" + f"# BEGIN REPRO SCRIPT\n" + f"# =========================================================\n" + f"import torch\n" + f"from torch.testing._internal.optests import opcheck\n" + f"\n" + f"# Make sure you have loaded the library that contains the op\n" + f"# via an import or torch.ops.load_library(...)\n" + f"op = torch.ops.{ns}.{name}.{overload}\n" + f"\n" + f"{args_kwargs}\n" + f'opcheck(op, args, kwargs, test_utils="{test}")\n' + f"# =========================================================\n" + f"# END REPRO SCRIPT\n" + f"# =========================================================\n" + ) + return repro_command + + +def resolve_unique_overload_or_throw( + op: torch._ops.OpOverloadPacket, +) -> torch._ops.OpOverload: + all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name) + if len(all_schemas) != 1: + raise RuntimeError( + f"opcheck can only test operators without overloads. " + f"Got the following overloads for {op._qualified_op_name}: " + f"{[schema.overload_name for schema in all_schemas]}" + ) + + overload_name = all_schemas[0].overload_name + if overload_name == "": + return op.default + return getattr(op, overload_name) + + +DUMP_OPTIONS = {"indent": 2, "sort_keys": True} + + +FailuresDictData = dict[str, dict[str, dict[str, str]]] + + +VERSION = 1 +DESCRIPTION = ( + f"This is a dict containing failures for tests autogenerated by " + f"generate_opcheck_tests. " + f"For more details, please see {GDOC}" +) + + +class FailuresDict: + def __init__(self, path: str, data: FailuresDictData): + self.path = path + self.data = data + + @staticmethod + def load(path, *, create_file=False) -> "FailuresDict": + if create_file and not os.path.exists(path): + result = FailuresDict(path, {}) + FailuresDict.save() + return result + with open(path) as fp: + contents = fp.read() + if contents.strip() == "": + dct = { + "_description": DESCRIPTION, + "data": {}, + "_version": VERSION, + } + else: + dct = json.loads(contents) + assert "data" in dct + assert "_version" in dct and dct["_version"] == VERSION + return FailuresDict(path, dct["data"]) + + def _save(self, to_str=False) -> Optional[str]: + to_dump = { + "_description": DESCRIPTION, + "data": self.data, + "_version": VERSION, + } + # json.dumps doesn't end with a newline. Let's add one because files + # should end in newlines. + serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n" + if to_str: + return serialized + with open(self.path, "w") as fp: + fp.write(serialized) + return None + + def save(self) -> None: + return self._save() + + def get_status(self, qualname: str, test_name: str) -> str: + if qualname not in self.data: + return "xsuccess" + dct = self.data[qualname] + if test_name not in dct: + return "xsuccess" + return dct[test_name]["status"] + + def set_status( + self, + qualname: str, + test_name: str, + status: str, + *, + comment: Optional[str] = None, + ): + if qualname not in self.data: + self.data[qualname] = {} + dct = self.data[qualname] + if test_name not in dct: + dct[test_name] = {"status": None, "comment": ""} + + if status == "xsuccess": + # The default status is "xsuccess". + del dct[test_name] + else: + dct[test_name]["status"] = status + if comment is not None: + dct[test_name]["comment"] = comment diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..970a0be1b36956d3693a5a93d07dbf32027c9773 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py @@ -0,0 +1,89 @@ +# mypy: ignore-errors + +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._utils import wrapper_set_seed +import torch.utils._pytree as pytree + + +def make_fx_check( + func, + args, + kwargs, + tracing_mode, + assert_close=torch.testing.assert_close, + randomize_data=False, +): + f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs) + + def run(f, *args, **kwargs): + return wrapper_set_seed(f, *args, **kwargs) + + traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args) + + msg = ( + "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different " + "values. This could mean that your abstract impls (meta/FakeTensor impls) " + "are incorrect, that your operator is not completely traceable (e.g., " + "it relies on some global state), or that there is a bug in make_fx. " + "Note that if you passed a python function (and not an operator) to " + "make_fx_check, it is still possible that the python function will still " + "work with torch.compile because it handles capturing pieces of " + "your python code to compile." + ) + + # Randomize the data and run the traced graph with it, to catch bugs + # where we may have baked in Tensor data into the trace. + # This is not guaranteed to succeed, because `f` might have preconditions + # on the values of the inputs, so we just ignore if we used + # random data and it fails. + if randomize_data: + new_args = randomize(new_args) + try: + expected = run(f, *new_args) + except Exception: + if randomize_data: + return + raise + result = run(traced_f, *new_args) + assert_close(result, expected, msg=msg) + + +# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes. +# Absent that, here is our strategy: +# +# If any argument is a torch.Size(), maybe get dynamic shapes for it by: +# - Create a temporary Tensor whose size is the torch.Size() we want. Note that +# we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx. +# - Pass it to make_fx such that it is converted to a proxy Tensor +# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in +# symbolic mode, a no-op otherwise) +def handle_sizes_for_dynamic_shapes(func, args, kwargs): + def f(args, kwargs, extra_args, extra_kwargs): + if extra_args: + for i, t in extra_args: + args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() + + return func(*args, **kwargs) + + extra_args = [] + extra_kwargs = {} + for i, arg in enumerate(args): + if isinstance(arg, torch.Size): + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") + + return f, args, kwargs, extra_args, extra_kwargs + + +def randomize(args): + def transform(x): + if not x.dtype.is_floating_point: + return x + return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad) + return pytree.tree_map_only(torch.Tensor, transform, args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py new file mode 100644 index 0000000000000000000000000000000000000000..abc4ab6f7e4734361ec7ecea3d4755910f9cf2ab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py @@ -0,0 +1,33 @@ +# mypy: ignore-errors + +import math + +import torch +import torch.nn as nn + + +class LinearReluFunctionalChild(nn.Module): + def __init__(self, N): + super().__init__() + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x + +class LinearReluFunctional(nn.Module): + def __init__(self, N): + super().__init__() + self.child = LinearReluFunctionalChild(N) + self.w1 = nn.Parameter(torch.empty(N, N)) + self.b1 = nn.Parameter(torch.zeros(N)) + torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) + + def forward(self, x): + x = self.child(x) + x = torch.nn.functional.linear(x, self.w1, self.b1) + x = torch.nn.functional.relu(x) + return x diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3494f945fad36d84cb8056dcf722d6911f0af2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py @@ -0,0 +1,10 @@ +# mypy: ignore-errors + + + +def div_int_future(): + return 1 / 2 + + +def div_float_future(): + return 3.14 / 0.125 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..164e6d168414a11039f3b63885760ad08b81ae99 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch # noqa: F401 + + +def div_int_nofuture(): + return 1 / 2 + + +def div_float_nofuture(): + return 3.14 / 0.125 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dff85be59148c87c4b1a8e0186a6bbe3f034496f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74144596dbc1c17bd39806fc044d69a311dffe62 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a25eaadf02e0a4758a014d9f6be5948742283c8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e2c02fce5d5ed3f6126b7ae2c0e53445b7250a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05421eb05d6302869e1b7771826ec32dd93742db Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c11ba16d703f9bd39359de5da3fe7d119f75a13 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b47cc53c3fec65d21739c04b287ef8eaa13462d4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def89ae362bae7cedbb9a70d0beddcb4f9dfc4de Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/xxhash/__pycache__/version.cpython-312.pyc differ